In [None]:
import torch
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import DataLoader

import nn_modules as nnm
import nn_preprocessor as nnp


# import filestructure as fs


class SimpleCNN(nn.Module):
    """Small 1-D CNN: two conv -> ReLU -> max-pool stages, then two linear layers.

    Args:
        in_channels: Number of channels of the input signal (default 1).
        num_classes: Width of the output layer (default 2, binary classification).
        input_length: Size of the input's last dimension. When ``None`` (default)
            the flattened feature size stays at the original ``32 * 122``, which
            matches inputs of length 494-497.

    Raises:
        ValueError: If ``input_length`` is too short to survive both conv/pool
            stages.
    """

    _CONV_KERNEL = 3
    _POOL_KERNEL = 2
    _CONV2_CHANNELS = 32
    _LEGACY_POOLED_LENGTH = 122

    def __init__(self, in_channels=1, num_classes=2, input_length=None):
        super().__init__()

        if input_length is None:
            pooled_length = self._LEGACY_POOLED_LENGTH
        else:
            # Two identical stages: unpadded stride-1 conv, then floor-mode pooling.
            pooled_length = self._stage_length(self._stage_length(input_length))
            if pooled_length < 1:
                raise ValueError(
                    f"input_length={input_length} is too short for two conv/pool stages"
                )

        # Layers are created in the same order and under the same attribute names
        # as before, so existing state_dicts keep loading.
        self.conv1 = nn.Conv1d(
            in_channels=in_channels, out_channels=16, kernel_size=self._CONV_KERNEL
        )
        self.conv2 = nn.Conv1d(
            in_channels=16,
            out_channels=self._CONV2_CHANNELS,
            kernel_size=self._CONV_KERNEL,
        )
        self.pool = nn.MaxPool1d(kernel_size=self._POOL_KERNEL)
        self.fc1 = nn.Linear(self._CONV2_CHANNELS * pooled_length, 128)
        self.fc2 = nn.Linear(128, num_classes)

    @classmethod
    def _stage_length(cls, length):
        """Return the sequence length after one conv + max-pool stage."""
        return (length - (cls._CONV_KERNEL - 1)) // cls._POOL_KERNEL

    def forward(self, x):
        """Map ``(batch, in_channels, length)`` input to ``(batch, num_classes)`` logits.

        A 2-D ``(in_channels, length)`` input is treated as a batch of one.

        Raises:
            RuntimeError: If the input length does not match the flattened size
                expected by ``fc1``.
        """
        if x.dim() == 2:
            x = x.unsqueeze(0)

        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))

        # Flatten per sample. The previous ``view(-1, 32 * 122)`` could silently
        # merge samples when the input length was wrong; keeping the batch axis
        # fixed makes such a mismatch fail loudly in ``fc1`` instead.
        x = torch.flatten(x, start_dim=1)

        x = torch.relu(self.fc1(x))
        return self.fc2(x)

In [None]:
# Hyperparameters in resume mode, pointing at the output path of an earlier run.
# NOTE(review): the path is hard-coded and looks like a run timestamp — confirm it
# exists in the environment where this notebook is executed.
hyp = nnp.Hyperparams(resume=True, output_path="2024-02-20_10.55.11.660910")

In [None]:
# Obtain the model, loss, optimizer and LR scheduler for the run described by `hyp`.
# NOTE(review): presumably restores a saved checkpoint under hyp.output_path —
# confirm against nn_modules.load_checkpoint.
model, criterion, optimizer, scheduler = nnm.load_checkpoint(hyp)

In [None]:
# --- Hyperparameters for this training run ---
# NOTE(review): this rebinds `hyp`, replacing the resume configuration created
# earlier in the notebook.
hyp = nnp.Hyperparams(
    epochs=100,
    batch_size=10_000,
    learning_rate=0.001,
    patience=30,
    eval_every=2,
)

# --- Dataset built from the "mz" source, using labels.tissue_type as labels and
# the grouping initialised with group_init="person" ---
labels = nnp.Labels("mz")
grouping = nnp.Grouping("mz", group_init="person")

custom_dataset = nnp.CustomDataset(
    "mz",
    labels.tissue_type,
    grouping.result,
    transpose=True,
)
custom_dataset.pre_transforms(
    transform=transforms.Compose(
        [
            nnp.ColPadding(custom_dataset.all_cols),
        ]
    )
)

# --- Group-wise split into three subsets ---
# The fractions 0.7 / 0.15 are passed positionally; presumably the train and
# validation shares with the remainder used for test — TODO confirm.
splitter = nnp.DatasetSplitter(custom_dataset, 0.7, 0.15)
train_dataset, val_dataset, test_dataset = splitter.group_split()

# --- One DataLoader per subset, all sharing batch size and shuffle flag ---
# NOTE(review): validation and test loaders also use hyp.shuffle; confirm that
# shuffling evaluation data is intended.
train_loader, val_loader, test_loader = (
    DataLoader(subset, batch_size=hyp.batch_size, shuffle=hyp.shuffle)
    for subset in (train_dataset, val_dataset, test_dataset)
)

In [None]:
# Number of batches the training loader yields per pass over the data.
len(train_loader)

In [None]:
# Build a separate, untransformed dataset object for inspection.
# NOTE(review): `labels` and `grouping` are rebuilt here with the same arguments
# as in the data-preparation cell above, rebinding those names.
labels = nnp.Labels("mz")
grouping = nnp.Grouping("mz", group_init="person")

# Unlike `custom_dataset`, no pre_transforms (column padding) are applied to `cdata`.
cdata = nnp.CustomDataset("mz", labels.tissue_type, grouping.result, transpose=True)

# `df` is the result of to_df(); `bdf` is the dataset's `df_data` attribute.
# NOTE(review): to_df() is called before df_data is read — keep this order unless
# CustomDataset confirms the two are independent.
df = cdata.to_df()
bdf = cdata.df_data

In [None]:
# Select the rows whose "groups" value is 27, take the first one and read its
# `pkl` field.
# NOTE(review): 27 is a hard-coded group id; `.iloc[0]` raises IndexError when no
# row matches.
data = bdf.loc[bdf["groups"] == 27].iloc[0].pkl
# Column labels of the selected object (presumably a DataFrame — TODO confirm).
data.columns

In [None]:
# Number of entries in the index of `data`, i.e. its row count.
len(data.index)