In [None]:
# Imports
from simple_mamba import all_at_the_same_time
from Dataloader import list_directory_tree, DatasetAllData
import torch.optim as optim

from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np

In [None]:
# Build the model and put it on the GPU explicitly: every input, target and
# criterion below is moved to "cuda". nn.Module.to returns the module itself,
# so this is a no-op if the constructor already placed it there.
# NOTE(review): d_model=6 presumably matches the number of input features per
# step (cf. the "6_features_saved_state" checkpoint folder) — confirm.
model = all_at_the_same_time(d_model=6, d_state=16, d_conv=2, expand=1, dropout=0.1).to("cuda")

In [None]:
class MSLELoss(nn.Module):
    """Mean squared logarithmic error.

    Computes ``mean((log(1 + pred) - log(1 + actual)) ** 2)`` over all elements.
    Both inputs must be > -1 element-wise; values <= -1 make the logarithm
    -inf/NaN. ``torch.log1p`` is used instead of ``torch.log(x + 1)`` because it
    stays accurate for small ``x``, where ``x + 1`` would round to 1 in float32.
    """

    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()

    def forward(self, pred, actual):
        """Return the scalar MSLE between ``pred`` and ``actual``."""
        return self.mse(torch.log1p(pred), torch.log1p(actual))

In [None]:
# Optimizer and learning-rate schedule. StepLR multiplies the LR by `gamma`
# every `step_size` calls to scheduler.step(); it must be stepped explicitly.
optimizer = optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.1)

# Loss functions for the alpha and K regression outputs
alpha_criterion = torch.nn.L1Loss().to("cuda")
K_criterion = MSLELoss().to("cuda")

# Split sizes, plus a seed that fixes the train/test split so a kernel restart
# (or evaluating a reloaded checkpoint) uses the same held-out files.
n_train_files = 10000
n_test_files = 100
split_seed = 0

# Create a list of all data files and shuffle it reproducibly
all_data_set = list_directory_tree(r"../../data/datasets")
np.random.default_rng(split_seed).shuffle(all_data_set)
assert len(all_data_set) > n_test_files, "not enough files for a train/test split"

# Test files come from the tail, training files from what is left, so the two
# sets never overlap. Plain [:10000] / [-100:] slices would overlap whenever
# fewer than 10100 files are found.
test_files = all_data_set[-n_test_files:]
train_files = all_data_set[:-n_test_files][:n_train_files]

# Create training and test datasets
training_dataset = DatasetAllData(train_files, transform=False, pad=(20, 200))
test_dataset = DatasetAllData(test_files, transform=False, pad=(20, 200))

# Create dataloaders. The test order is fixed: evaluation does not need
# shuffling, and a fixed order keeps the per-batch class weighting (and hence
# the test loss) comparable between epochs.
dataloader = DataLoader(training_dataset, shuffle=True, batch_size=10, num_workers=5)
test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=10, num_workers=5)

In [None]:
def compute_test_loss(model, test_dataloader, alpha_criterion, K_criterion, device="cuda"):
    """Evaluate ``model`` on ``test_dataloader`` without gradient tracking.

    Each item of ``test_dataloader`` must be a 4-tuple ``(inputs,
    classification_targets, K_targets, alpha_targets)``. The first two
    dimensions of every tensor are merged into one before use.
    ``model(inputs)`` must return ``(classification_output, alpha_output,
    K_output)``; the classification output is reshaped to ``(-1, 3)``, i.e.
    three classes.

    Class 0 is excluded from the classification loss (``ignore_index=0``).
    Classes 1 and 2 are re-weighted per batch by ``n_labelled / (2 * count)``
    so both contribute equally; a class absent from the batch gets weight 0.

    Args:
        model: module under evaluation. Its train/eval mode is restored on exit.
        test_dataloader: iterable of batches as described above.
        alpha_criterion: loss applied to ``(alpha_output, alpha_targets)``.
        K_criterion: loss applied to ``(K_output, K_targets)``.
        device: device the batches are moved to (default ``"cuda"``).

    Returns:
        ``(classification, alpha, K)``: means of the per-batch losses. Batches
        with no class-1/2 label are left out of the classification mean,
        because their weighted loss is 0/0. A mean over zero batches is NaN.
    """
    was_training = model.training
    model.eval()
    classification_losses = []
    alpha_losses = []
    K_losses = []

    try:
        with torch.no_grad():
            for inputs, classification_targets, K_targets, alpha_targets in test_dataloader:
                inputs = torch.flatten(
                    inputs.to(device, dtype=torch.float32), start_dim=0, end_dim=1
                )
                labels = torch.flatten(
                    classification_targets.to(device, dtype=torch.long), start_dim=0, end_dim=1
                ).reshape(-1)
                alpha_targets = torch.flatten(
                    alpha_targets.to(device, dtype=torch.float32), start_dim=0, end_dim=1
                )
                K_targets = torch.flatten(
                    K_targets.to(device, dtype=torch.float32), start_dim=0, end_dim=1
                )

                classification_output, alpha_output, K_output = model(inputs)
                logits = torch.squeeze(classification_output).reshape(-1, 3).float()

                # bincount keeps each class in a fixed slot even when it is
                # absent from the batch; counts from torch.unique would shift
                # into the wrong slot (or fail to broadcast) in that case.
                label_counts = torch.bincount(labels, minlength=3)[1:3].to(
                    logits.device, torch.float32
                )
                n_labelled = label_counts.sum()
                if n_labelled > 0:
                    class_weight = torch.zeros(3, device=logits.device)
                    class_weight[1:] = torch.where(
                        label_counts > 0,
                        n_labelled / (2 * label_counts),
                        torch.zeros_like(label_counts),
                    )
                    classification_loss = nn.functional.cross_entropy(
                        logits,
                        labels.to(logits.device),
                        weight=class_weight,
                        ignore_index=0,
                    )
                    classification_losses.append(classification_loss.item())

                alpha_losses.append(alpha_criterion(alpha_output, alpha_targets).item())
                K_losses.append(K_criterion(K_output, K_targets).item())
    finally:
        # Leave the model in the mode the caller had it in.
        model.train(was_training)

    return (
        np.mean(classification_losses),
        np.mean(alpha_losses),
        np.mean(K_losses),
    )

In [None]:
import os  # local import: needed to create the checkpoint directory

# NOTE(review): `plt` (matplotlib.pyplot) is used below but is not imported in
# any visible cell — add `import matplotlib.pyplot as plt` to the imports cell.

max_epoch = 50
checkpoint_dir = "6_features_saved_state"
os.makedirs(checkpoint_dir, exist_ok=True)  # torch.save does not create folders

# Per-epoch mean losses (train and test), used for the learning-curve plots.
total_classification_loss = []
total_K_loss = []
total_alpha_loss = []
test_classification_loss = []
test_K_loss = []
test_alpha_loss = []

for epoch in range(max_epoch):
    running_classification_loss = []
    running_alpha_loss = []
    running_K_loss = []

    # compute_test_loss() switches the model to eval mode, so re-enable
    # training mode at the start of every epoch.
    model.train()

    with tqdm(dataloader, unit="batch") as tepoch:
        tepoch.set_description(f"Epoch {epoch}")

        for inputs, classification_targets, K_targets, alpha_targets in tepoch:
            # Every tensor is flattened over its first two dims into one batch
            # dim — presumably (loader batch, samples per file); TODO confirm.
            inputs = torch.flatten(
                inputs.to("cuda", dtype=torch.float32), start_dim=0, end_dim=1
            )
            labels = torch.flatten(
                classification_targets.to("cuda", dtype=torch.long), start_dim=0, end_dim=1
            ).reshape(-1)
            alpha_targets = torch.flatten(
                alpha_targets.to("cuda", dtype=torch.float32), start_dim=0, end_dim=1
            )
            K_targets = torch.flatten(
                K_targets.to("cuda", dtype=torch.float32), start_dim=0, end_dim=1
            )

            optimizer.zero_grad()
            classification_output, alpha_output, K_output = model(inputs)
            logits = torch.squeeze(classification_output).reshape(-1, 3).float()

            # Class 0 is ignored; classes 1 and 2 are weighted
            # n_labelled / (2 * count) so they contribute equally. bincount keeps
            # each class in its own slot even if it is missing from the batch.
            label_counts = torch.bincount(labels, minlength=3)[1:3].to(
                logits.device, torch.float32
            )
            n_labelled = label_counts.sum()
            has_labels = bool(n_labelled > 0)
            if has_labels:
                class_weight = torch.zeros(3, device=logits.device)
                class_weight[1:] = torch.where(
                    label_counts > 0,
                    n_labelled / (2 * label_counts),
                    torch.zeros_like(label_counts),
                )
                classification_loss = nn.functional.cross_entropy(
                    logits, labels.to(logits.device), weight=class_weight, ignore_index=0
                )
            else:
                # Nothing to classify: the weighted mean would be 0/0 = NaN, and
                # a NaN loss would corrupt every weight on backward().
                classification_loss = torch.zeros((), device=logits.device)

            alpha_loss = alpha_criterion(alpha_output, alpha_targets)
            K_loss = K_criterion(K_output, K_targets)
            total_loss = alpha_loss + K_loss + classification_loss

            total_loss.backward()
            optimizer.step()

            tepoch.set_postfix(
                loss_c=classification_loss.item(),
                loss_a=alpha_loss.item(),
                loss_K=K_loss.item(),
            )

            if has_labels:
                running_classification_loss.append(classification_loss.item())
            running_alpha_loss.append(alpha_loss.item())
            running_K_loss.append(K_loss.item())

    # StepLR counts epochs; without this call the LR never decays.
    scheduler.step()

    epoch_test_class_loss, epoch_test_alpha_loss, epoch_test_K_loss = compute_test_loss(
        model, test_dataloader, alpha_criterion, K_criterion
    )
    total_classification_loss.append(np.mean(running_classification_loss))
    total_alpha_loss.append(np.mean(running_alpha_loss))
    total_K_loss.append(np.mean(running_K_loss))

    # Each test loss goes to its own list (alpha and K used to be swapped).
    test_classification_loss.append(epoch_test_class_loss)
    test_alpha_loss.append(epoch_test_alpha_loss)
    test_K_loss.append(epoch_test_K_loss)

    torch.save(
        model.state_dict(),
        f"{checkpoint_dir}/100k_files_training_new_bimamba_epoch_{epoch}",
    )

    # Learning curves so far: solid = train, dashed = test, same colour per loss.
    plt.figure(dpi=300)
    for name, train_curve, test_curve in (
        ("alpha", total_alpha_loss, test_alpha_loss),
        ("classification", total_classification_loss, test_classification_loss),
        ("k", total_K_loss, test_K_loss),
    ):
        (train_line,) = plt.semilogy(train_curve, label=name)
        plt.semilogy(test_curve, "--", color=train_line.get_color(), label=f"{name} (test)")
    plt.xlabel("epoch")
    plt.ylabel("loss")
    plt.legend()
    plt.show()

In [None]:
# Zoom into the late epochs. Each curve is plotted against its true epoch
# number, so the x axis stays correct after slicing.
# NOTE(review): with max_epoch = 50 in the training cell, nothing from epoch 90
# onwards exists and this figure is empty — adjust first_epoch to the run.
first_epoch = 90

plt.figure(dpi=300)
for name, train_curve, test_curve in (
    ("alpha", total_alpha_loss, test_alpha_loss),
    ("classification", total_classification_loss, test_classification_loss),
    ("k", total_K_loss, test_K_loss),
):
    train_tail = train_curve[first_epoch:]
    test_tail = test_curve[first_epoch:]
    (train_line,) = plt.semilogy(
        range(first_epoch, first_epoch + len(train_tail)), train_tail, label=name
    )
    plt.semilogy(
        range(first_epoch, first_epoch + len(test_tail)),
        test_tail,
        "--",
        color=train_line.get_color(),
        label=f"{name} (test)",
    )
plt.ylim((1e-2, 1))
plt.xlabel("epoch")
plt.ylabel("loss")
plt.legend()
plt.savefig("1000epoch18kfiles.pdf")

In [None]:
# Classification loss on log-log axes. Epochs are numbered from 1 here because
# a logarithmic x axis cannot show epoch 0.
plt.figure(dpi=300)
(train_line,) = plt.loglog(
    range(1, len(total_classification_loss) + 1),
    total_classification_loss,
    label="classification",
)
plt.loglog(
    range(1, len(test_classification_loss) + 1),
    test_classification_loss,
    "--",
    color=train_line.get_color(),
    label="classification (test)",
)
plt.xlabel("epoch (1-based)")
plt.ylabel("loss")
plt.legend()
# Own file name: the zoomed loss plot cell already writes 1000epoch18kfiles.pdf.
plt.savefig("1000epoch18kfiles_classification.pdf")