In [1]:
import enum
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torch.optim as optim
from PIL import Image
from sklearn import metrics
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

In [2]:
class Resnet50Encoder(nn.Module):

    def __init__(self):
        super().__init__()
        self.encoder = models.resnet50()
        self.encoder = nn.Sequential(*list(self.encoder.children())[:-1])

    def forward(self, x):
        x = self.encoder(x)
        x = x.view(x.size(0), x.size(1))
        return x

In [19]:
class VGG16Encoder(nn.Module):

    def __init__(self):
        super().__init__()
        self.encoder = models.vgg16()
        self.encoder = nn.Sequential(*list(self.encoder.children())[:-1])

    def forward(self, x):
        x = self.encoder(x)
        x = x.view(x.size(0), 512 * 7 * 7)
        return x

In [3]:
class ImageAggregator(nn.Module):

    def __init__(
        self,
        d_model=2048,
        aggregator=lambda x: torch.mean(x, dim=-2),
        num_classes=2
    ):
        super().__init__()
        self.d_model = d_model
        self.aggregator = aggregator
        self.head = nn.Sequential(
            nn.Linear(d_model, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, num_classes)
        )

    def forward(self, x):
        x = self.aggregator(x)
        x = self.head(x)
        return x

In [4]:
def image_to_numpy(image_path):
    image_pil = Image.open(image_path)
    image_np = np.array(image_pil, dtype=np.float32) / 255.0
    return image_np

In [5]:
class PatientImageDataset(Dataset):

    def __init__(self, 
        patient_path, 
        transform_function=None,
    ):
        super().__init__()
        self.patient_path = patient_path
        self.transform_function = transform_function

        images_path = []
        for image_path in os.listdir(patient_path):
            images_path.append(os.path.join(patient_path, image_path))

        self.images_path = images_path

    def __len__(self):
        return len(self.images_path)

    def __getitem__(self, index):
        image_path = self.images_path[index]
        image = image_to_numpy(image_path)
        if self.transform_function is not None:
            image = self.transform_function(image)
        return image

In [6]:
class PatientsDataset(Dataset):

    def __init__(
        self, 
        split_dir, 
        patients_ids,
        patients_labels
    ):
        super().__init__()
        self.split_dir = split_dir
        self.patients_ids = patients_ids
        self.patients_labels = patients_labels

    def __len__(self):
        return len(self.patients_ids)

    def __getitem__(self, index):
        patient_id = self.patients_ids[index]
        patient_label = self.patients_labels[index]
        patient_path = os.path.join(self.split_dir, patient_id)
        return {
            "id": patient_id, 
            "label": patient_label, 
            "patient_path": patient_path
        }

In [7]:
data_dir = "/kaggle/input/dlmi-lymphocytosis/dlmi-lymphocytosis-classification/"
#data_dir = "../../data/dlmi-lymphocytosis-classification/"
train_dir = data_dir + "trainset/"
test_dir = data_dir + "testset/"
save_dir = "/kaggle/working/"

In [8]:
trainset_true_df = pd.read_csv(data_dir + "trainset/trainset_true.csv")
testset_data_df = pd.read_csv(data_dir + "testset/testset_data.csv")

In [9]:
id_train_val = trainset_true_df["ID"].to_numpy()
labels = trainset_true_df["LABEL"].to_numpy()
labels_dict = dict(zip(id_train_val, labels))

id_train_val_0 = id_train_val[labels == 0]
id_train_val_1 = id_train_val[labels == 1]

n_id_val_per_class = 10
id_val_0 = np.random.choice(id_train_val_0, size=n_id_val_per_class, replace=False)
id_val_1 = np.random.choice(id_train_val_1, size=n_id_val_per_class, replace=False)

id_val = np.concatenate([id_val_0, id_val_1])
mask = np.isin(id_train_val, id_val, invert=True)
id_train = id_train_val[mask]

id_train.shape, id_val.shape

((143,), (20,))

In [10]:
batch_size = 1
seq_len = 64

In [11]:
train_dataset = PatientsDataset(
    split_dir=train_dir,
    patients_ids=id_train,
    patients_labels=np.array([labels_dict[id] for id in id_train]),
)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

val_dataset = PatientsDataset(
    split_dir=train_dir,
    patients_ids=id_val,
    patients_labels=np.array([labels_dict[id] for id in id_val]),
)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

train_val_dataset = PatientsDataset(
    split_dir=train_dir,
    patients_ids=id_train_val,
    patients_labels=labels,
)
train_val_dataloader = DataLoader(train_val_dataset, batch_size=batch_size, shuffle=False)

test_dataset = PatientsDataset(
    split_dir=test_dir,
    patients_ids=testset_data_df["ID"].to_numpy(),
    patients_labels=[-1] * len(trainset_true_df["ID"]),
)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [12]:
def train(
    model,
    image_encoder,
    optimizer,
    criterion,
    dataloader,
    seq_len,
    device
):
    total_loss = 0
    all_preds = []
    all_true = []

    for data in tqdm(dataloader, "Train"):

        patient_label_item = data["label"].item()
        patient_label = data["label"].to(device)
        batch_preds = []
        patient_loss = 0

        patient_path = data["patient_path"][0]
        patient_dataset = PatientImageDataset(
            patient_path=patient_path,
            transform_function=transform_function,
        )
        patient_dataloader = DataLoader(patient_dataset, batch_size=seq_len, shuffle=True)

        for batch_images in patient_dataloader:
            batch_images = batch_images.unsqueeze(0).to(device)
            
            with torch.no_grad():
                batch_size, seq_len, im_chan, im_width, im_height = batch_images.size()
                x = batch_images.view(batch_size * seq_len, im_chan, im_width, im_height)
                x = image_encoder(x)
                x = x.view(batch_size, seq_len, model.d_model)

            logits = model(x)
            loss = criterion(logits, patient_label)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            patient_loss += loss.item()
            preds = logits.argmax(dim=-1).cpu().detach().numpy()
            batch_preds.extend(preds)

        patient_loss /= len(patient_dataset)
        total_loss += patient_loss
        patient_pred = int(np.mean(batch_preds) > 0.5)
        all_preds.append(patient_pred)
        all_true.append(patient_label_item)

    total_loss /= len(dataloader.dataset)
    acc = metrics.accuracy_score(all_true, all_preds)
    bacc = metrics.balanced_accuracy_score(all_true, all_preds)
    f1 = metrics.f1_score(all_true, all_preds)

    return acc, bacc, f1, total_loss

In [13]:
def test(
    model,
    image_encoder,
    criterion,
    dataloader,
    seq_len,
    device
):
    total_loss = 0
    all_preds = []
    all_true = []

    with torch.no_grad():
        for data in tqdm(dataloader, "Test"):

            patient_label_item = data["label"].item()
            patient_label = data["label"].to(device)
            batch_preds = []
            patient_loss = 0

            patient_path = data["patient_path"][0]
            patient_dataset = PatientImageDataset(
                patient_path=patient_path,
                transform_function=transform_function,
            )
            patient_dataloader = DataLoader(patient_dataset, batch_size=seq_len, shuffle=False)

            for batch_images in patient_dataloader:
                batch_images = batch_images.unsqueeze(0).to(device)

                batch_size, seq_len, im_chan, im_width, im_height = batch_images.size()
                x = batch_images.view(batch_size * seq_len, im_chan, im_width, im_height)
                x = image_encoder(x)
                x = x.view(batch_size, seq_len, model.d_model)

                logits = model(x)
                loss = criterion(logits, patient_label)

                patient_loss += loss.item()
                preds = logits.argmax(dim=-1).cpu().detach().numpy()
                batch_preds.extend(preds)

            patient_loss /= len(patient_dataset)
            total_loss += patient_loss
            patient_pred = int(np.mean(batch_preds) > 0.5)
            all_preds.append(patient_pred)
            all_true.append(patient_label_item)

    total_loss /= len(dataloader.dataset)
    acc = metrics.accuracy_score(all_true, all_preds)
    bacc = metrics.balanced_accuracy_score(all_true, all_preds)
    f1 = metrics.f1_score(all_true, all_preds)

    return acc, bacc, f1, total_loss

In [14]:
def predict(
    model,
    image_encoder,
    dataloader,
    seq_len,
    device
):
    all_ids = []
    all_preds = []

    with torch.no_grad():
        for data in tqdm(dataloader, "Predict"):

            patient_id = data["id"][0]
            batch_preds = []

            patient_path = data["patient_path"][0]
            patient_dataset = PatientImageDataset(
                patient_path=patient_path,
                transform_function=transform_function,
            )
            patient_dataloader = DataLoader(patient_dataset, batch_size=seq_len, shuffle=False)

            for batch_images in patient_dataloader:
                batch_images = batch_images.unsqueeze(0).to(device)
                
                batch_size, seq_len, im_chan, im_width, im_height = batch_images.size()
                x = batch_images.view(batch_size * seq_len, im_chan, im_width, im_height)
                x = image_encoder(x)
                x = x.view(batch_size, seq_len, model.d_model)

                logits = model(x)
                preds = logits.argmax(dim=-1).cpu().detach().numpy()
                batch_preds.extend(preds)

            patient_pred = int(np.mean(batch_preds) > 0.5)
            all_preds.append(patient_pred)
            all_ids.append(patient_id)

    return all_ids, all_preds

In [15]:
def trainer(
    model, 
    image_encoder,
    optimizer, 
    criterion, 
    train_dataloader, 
    test_dataloader, 
    n_epochs,
    seq_len,
    device,
    model_name
):
    train_accs      = []
    train_baccs     = []
    train_f1s       = []
    train_losses    = []
    test_accs       = []
    test_baccs      = []
    test_f1s        = []
    test_losses     = []

    for epoch in range(1, n_epochs + 1):
        train_acc, train_bacc, train_f1, train_loss = train(
            model, image_encoder, optimizer, criterion, train_dataloader, seq_len, device
        )
        test_acc, test_bacc, test_f1, test_loss = test(
            model, image_encoder, criterion, test_dataloader, seq_len, device
        )

        print(
            f"[Epoch {epoch} / {n_epochs}]",
            f"\ttrain loss = {train_loss:.4f} acc = {train_acc:.4f} bacc = {train_bacc:.4f} f1 = {train_f1:.4f}",
            f"\ttest loss = {test_loss:.4f} acc = {test_acc:.4f} bacc = {test_bacc:.4f} f1 = {test_f1:.4f}", 
            sep="\n"
        )

        train_accs.append(train_acc)
        train_baccs.append(train_bacc)
        train_f1s.append(train_f1)
        train_losses.append(train_loss)

        test_accs.append(test_acc)
        test_baccs.append(test_bacc)
        test_f1s.append(test_f1)
        test_losses.append(test_loss)

        torch.save(
            {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }, 
            os.path.join(save_dir, f"{model_name}.pt")
        )

    plt.figure()
    plt.plot(train_accs, label="Train")
    plt.plot(test_accs, label="Test")
    plt.title("Accuracy")
    plt.show()

    plt.figure()
    plt.plot(train_baccs, label="Train")
    plt.plot(test_baccs, label="Test")
    plt.title("Balanced accuracy")
    plt.show()

    plt.figure()
    plt.plot(train_f1s, label="Train")
    plt.plot(test_f1s, label="Test")
    plt.title("F1 Score")
    plt.show()

    plt.figure()
    plt.plot(train_losses, label="Train")
    plt.plot(test_losses, label="Test")
    plt.title("Loss")
    plt.show()
    

In [None]:
def load_checkpoint(checkpoint_path, model_args={}, optimizer_args={}):
    model = ImageAggregator(**model_args)
    optimizer = optim.Adam(model.parameters(), **optimizer_args)

    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    
    return model, optimizer, epoch

### Mean

In [23]:
lr = 1e-1
n_epochs = 100

image_encoder = VGG16Encoder() #Resnet50Encoder()
d_model = 25_088 #2048

model = ImageAggregator(
    d_model=d_model,
    aggregator=lambda x: torch.mean(x, dim=-2)
)

optimizer = optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss(reduction="sum")
transform_function = lambda x: x.reshape(3, x.shape[0], x.shape[1])
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

image_encoder.to(device)
model.to(device)

print(device)

cuda


In [None]:
trainer(
    model,
    image_encoder,
    optimizer,
    criterion,
    train_dataloader,
    val_dataloader,
    n_epochs,
    seq_len,
    device,
    "mean_model"
)

Train: 100%|██████████| 143/143 [01:22<00:00,  1.73it/s]
Test: 100%|██████████| 20/20 [00:09<00:00,  2.14it/s]


[Epoch 1 / 100]
	train loss = 1697.3997 acc = 0.9860 bacc = 0.9826 f1 = 0.9903
	test loss = 2.5464 acc = 0.5000 bacc = 0.5000 f1 = 0.0000


Train: 100%|██████████| 143/143 [01:23<00:00,  1.72it/s]
Test: 100%|██████████| 20/20 [00:09<00:00,  2.10it/s]


[Epoch 2 / 100]
	train loss = 0.5024 acc = 0.9371 bacc = 0.9410 f1 = 0.9552
	test loss = 2.8605 acc = 0.5000 bacc = 0.5000 f1 = 0.6667


Train: 100%|██████████| 143/143 [01:24<00:00,  1.70it/s]
Test: 100%|██████████| 20/20 [00:09<00:00,  2.18it/s]


[Epoch 3 / 100]
	train loss = 0.4848 acc = 0.9580 bacc = 0.9479 f1 = 0.9709
	test loss = 2.5436 acc = 0.5000 bacc = 0.5000 f1 = 0.0000


Train: 100%|██████████| 143/143 [01:23<00:00,  1.72it/s]
Test: 100%|██████████| 20/20 [00:09<00:00,  2.08it/s]


[Epoch 4 / 100]
	train loss = 0.4341 acc = 0.9650 bacc = 0.9681 f1 = 0.9754
	test loss = 2.4369 acc = 0.5000 bacc = 0.5000 f1 = 0.6667


Train: 100%|██████████| 143/143 [01:22<00:00,  1.74it/s]
Test: 100%|██████████| 20/20 [00:09<00:00,  2.15it/s]


[Epoch 5 / 100]
	train loss = 0.4319 acc = 0.9510 bacc = 0.9354 f1 = 0.9662
	test loss = 2.7946 acc = 0.5000 bacc = 0.5000 f1 = 0.6667


Train: 100%|██████████| 143/143 [01:23<00:00,  1.72it/s]
Test: 100%|██████████| 20/20 [00:09<00:00,  2.08it/s]


[Epoch 6 / 100]
	train loss = 0.4430 acc = 0.9650 bacc = 0.9528 f1 = 0.9758
	test loss = 2.3066 acc = 0.5000 bacc = 0.5000 f1 = 0.6667


Train: 100%|██████████| 143/143 [01:22<00:00,  1.74it/s]
Test: 100%|██████████| 20/20 [00:09<00:00,  2.12it/s]


[Epoch 7 / 100]
	train loss = 0.5064 acc = 0.9441 bacc = 0.9229 f1 = 0.9615
	test loss = 2.0477 acc = 0.5000 bacc = 0.5000 f1 = 0.6667


Train: 100%|██████████| 143/143 [01:23<00:00,  1.71it/s]
Test: 100%|██████████| 20/20 [00:09<00:00,  2.11it/s]


[Epoch 8 / 100]
	train loss = 0.4475 acc = 0.9720 bacc = 0.9653 f1 = 0.9806
	test loss = 2.5424 acc = 0.5000 bacc = 0.5000 f1 = 0.6667


Train: 100%|██████████| 143/143 [01:22<00:00,  1.73it/s]
Test: 100%|██████████| 20/20 [00:09<00:00,  2.15it/s]


[Epoch 9 / 100]
	train loss = 0.4428 acc = 0.9720 bacc = 0.9576 f1 = 0.9808
	test loss = 3.2019 acc = 0.5000 bacc = 0.5000 f1 = 0.0000


Train: 100%|██████████| 143/143 [01:22<00:00,  1.73it/s]
Test: 100%|██████████| 20/20 [00:09<00:00,  2.16it/s]


[Epoch 10 / 100]
	train loss = 0.4642 acc = 0.9441 bacc = 0.9382 f1 = 0.9608
	test loss = 2.6954 acc = 0.5000 bacc = 0.5000 f1 = 0.6667


Train: 100%|██████████| 143/143 [01:23<00:00,  1.71it/s]
Test: 100%|██████████| 20/20 [00:09<00:00,  2.12it/s]


[Epoch 11 / 100]
	train loss = 0.4446 acc = 0.9580 bacc = 0.9326 f1 = 0.9714
	test loss = 1.8475 acc = 0.5000 bacc = 0.5000 f1 = 0.6667


Train: 100%|██████████| 143/143 [01:22<00:00,  1.73it/s]
Test: 100%|██████████| 20/20 [00:09<00:00,  2.11it/s]


[Epoch 12 / 100]
	train loss = 0.4273 acc = 0.9720 bacc = 0.9653 f1 = 0.9806
	test loss = 2.7039 acc = 0.5000 bacc = 0.5000 f1 = 0.6667


Train: 100%|██████████| 143/143 [01:22<00:00,  1.73it/s]
Test: 100%|██████████| 20/20 [00:09<00:00,  2.16it/s]


[Epoch 13 / 100]
	train loss = 0.5267 acc = 0.9510 bacc = 0.9354 f1 = 0.9662
	test loss = 2.7863 acc = 0.5000 bacc = 0.5000 f1 = 0.6667


Train: 100%|██████████| 143/143 [01:22<00:00,  1.73it/s]
Test: 100%|██████████| 20/20 [00:09<00:00,  2.12it/s]


[Epoch 14 / 100]
	train loss = 0.4633 acc = 0.9790 bacc = 0.9701 f1 = 0.9855
	test loss = 2.7281 acc = 0.5000 bacc = 0.5000 f1 = 0.6667


Train: 100%|██████████| 143/143 [01:22<00:00,  1.72it/s]
Test: 100%|██████████| 20/20 [00:09<00:00,  2.13it/s]


[Epoch 15 / 100]
	train loss = 0.4871 acc = 0.9650 bacc = 0.9604 f1 = 0.9756
	test loss = 3.0446 acc = 0.5000 bacc = 0.5000 f1 = 0.0000


Train: 100%|██████████| 143/143 [01:22<00:00,  1.73it/s]
Test: 100%|██████████| 20/20 [00:09<00:00,  2.17it/s]


[Epoch 16 / 100]
	train loss = 0.5188 acc = 0.9441 bacc = 0.9459 f1 = 0.9604
	test loss = 2.1936 acc = 0.5000 bacc = 0.5000 f1 = 0.6667


Train: 100%|██████████| 143/143 [01:22<00:00,  1.72it/s]
Test: 100%|██████████| 20/20 [00:09<00:00,  2.06it/s]


[Epoch 17 / 100]
	train loss = 0.4403 acc = 0.9580 bacc = 0.9403 f1 = 0.9712
	test loss = 3.0610 acc = 0.5000 bacc = 0.5000 f1 = 0.0000


Train: 100%|██████████| 143/143 [01:22<00:00,  1.72it/s]
Test: 100%|██████████| 20/20 [00:09<00:00,  2.12it/s]


[Epoch 18 / 100]
	train loss = 0.5048 acc = 0.9301 bacc = 0.9209 f1 = 0.9510
	test loss = 2.6532 acc = 0.5000 bacc = 0.5000 f1 = 0.6667


Train: 100%|██████████| 143/143 [01:23<00:00,  1.71it/s]
Test: 100%|██████████| 20/20 [00:09<00:00,  2.13it/s]


[Epoch 19 / 100]
	train loss = 0.4420 acc = 0.9790 bacc = 0.9625 f1 = 0.9856
	test loss = 2.5060 acc = 0.5000 bacc = 0.5000 f1 = 0.0000


Train: 100%|██████████| 143/143 [01:23<00:00,  1.72it/s]
Test: 100%|██████████| 20/20 [00:09<00:00,  2.13it/s]


[Epoch 20 / 100]
	train loss = 0.4730 acc = 0.9650 bacc = 0.9681 f1 = 0.9754
	test loss = 2.8768 acc = 0.5000 bacc = 0.5000 f1 = 0.0000


Train: 100%|██████████| 143/143 [01:23<00:00,  1.71it/s]
Test: 100%|██████████| 20/20 [00:09<00:00,  2.13it/s]


[Epoch 21 / 100]
	train loss = 0.5203 acc = 0.9510 bacc = 0.9584 f1 = 0.9652
	test loss = 2.2180 acc = 0.5000 bacc = 0.5000 f1 = 0.6667


Train: 100%|██████████| 143/143 [01:23<00:00,  1.72it/s]
Test: 100%|██████████| 20/20 [00:09<00:00,  2.11it/s]


[Epoch 22 / 100]
	train loss = 0.4810 acc = 0.9720 bacc = 0.9729 f1 = 0.9804
	test loss = 2.3682 acc = 0.5000 bacc = 0.5000 f1 = 0.6667


Train: 100%|██████████| 143/143 [01:23<00:00,  1.71it/s]
Test: 100%|██████████| 20/20 [00:09<00:00,  2.13it/s]


[Epoch 23 / 100]
	train loss = 0.3962 acc = 0.9790 bacc = 0.9701 f1 = 0.9855
	test loss = 2.8710 acc = 0.5000 bacc = 0.5000 f1 = 0.0000


Train: 100%|██████████| 143/143 [01:23<00:00,  1.70it/s]
Test: 100%|██████████| 20/20 [00:09<00:00,  2.09it/s]


[Epoch 24 / 100]
	train loss = 0.4948 acc = 0.9441 bacc = 0.9459 f1 = 0.9604
	test loss = 2.4006 acc = 0.5000 bacc = 0.5000 f1 = 0.6667


Train: 100%|██████████| 143/143 [01:22<00:00,  1.72it/s]
Test: 100%|██████████| 20/20 [00:09<00:00,  2.16it/s]


[Epoch 25 / 100]
	train loss = 0.4236 acc = 0.9650 bacc = 0.9604 f1 = 0.9756
	test loss = 2.4457 acc = 0.5000 bacc = 0.5000 f1 = 0.6667


Train: 100%|██████████| 143/143 [01:22<00:00,  1.73it/s]
Test: 100%|██████████| 20/20 [00:09<00:00,  2.08it/s]


[Epoch 26 / 100]
	train loss = 0.4947 acc = 0.9510 bacc = 0.9354 f1 = 0.9662
	test loss = 2.7054 acc = 0.5000 bacc = 0.5000 f1 = 0.6667


Train: 100%|██████████| 143/143 [01:23<00:00,  1.72it/s]
Test: 100%|██████████| 20/20 [00:09<00:00,  2.14it/s]


[Epoch 27 / 100]
	train loss = 0.4225 acc = 0.9580 bacc = 0.9479 f1 = 0.9709
	test loss = 2.4487 acc = 0.5000 bacc = 0.5000 f1 = 0.6667


Train: 100%|██████████| 143/143 [01:23<00:00,  1.71it/s]
Test: 100%|██████████| 20/20 [00:09<00:00,  2.11it/s]


[Epoch 28 / 100]
	train loss = 0.4134 acc = 0.9650 bacc = 0.9451 f1 = 0.9761
	test loss = 2.6646 acc = 0.5000 bacc = 0.5000 f1 = 0.6667


Train: 100%|██████████| 143/143 [01:22<00:00,  1.73it/s]
Test: 100%|██████████| 20/20 [00:09<00:00,  2.14it/s]


[Epoch 29 / 100]
	train loss = 0.4459 acc = 0.9720 bacc = 0.9576 f1 = 0.9808
	test loss = 2.3925 acc = 0.5000 bacc = 0.5000 f1 = 0.6667


Train: 100%|██████████| 143/143 [01:22<00:00,  1.73it/s]
Test: 100%|██████████| 20/20 [00:09<00:00,  2.14it/s]


[Epoch 30 / 100]
	train loss = 0.4265 acc = 0.9510 bacc = 0.9278 f1 = 0.9665
	test loss = 2.2305 acc = 0.5000 bacc = 0.5000 f1 = 0.6667


Train: 100%|██████████| 143/143 [01:22<00:00,  1.73it/s]
Test: 100%|██████████| 20/20 [00:09<00:00,  2.17it/s]


[Epoch 31 / 100]
	train loss = 0.3947 acc = 0.9790 bacc = 0.9701 f1 = 0.9855
	test loss = 3.5768 acc = 0.5000 bacc = 0.5000 f1 = 0.0000


Train: 100%|██████████| 143/143 [01:22<00:00,  1.74it/s]
Test: 100%|██████████| 20/20 [00:09<00:00,  2.16it/s]


[Epoch 32 / 100]
	train loss = 0.4374 acc = 0.9650 bacc = 0.9681 f1 = 0.9754
	test loss = 2.5120 acc = 0.5000 bacc = 0.5000 f1 = 0.6667


Train: 100%|██████████| 143/143 [01:21<00:00,  1.76it/s]
Test: 100%|██████████| 20/20 [00:09<00:00,  2.18it/s]


[Epoch 33 / 100]
	train loss = 0.4471 acc = 0.9510 bacc = 0.9354 f1 = 0.9662
	test loss = 2.5942 acc = 0.5000 bacc = 0.5000 f1 = 0.6667


Train: 100%|██████████| 143/143 [01:22<00:00,  1.74it/s]
Test: 100%|██████████| 20/20 [00:09<00:00,  2.20it/s]


[Epoch 34 / 100]
	train loss = 0.4564 acc = 0.9790 bacc = 0.9701 f1 = 0.9855
	test loss = 2.2418 acc = 0.5000 bacc = 0.5000 f1 = 0.6667


Train: 100%|██████████| 143/143 [01:21<00:00,  1.76it/s]
Test: 100%|██████████| 20/20 [00:09<00:00,  2.19it/s]


[Epoch 35 / 100]
	train loss = 0.4151 acc = 0.9231 bacc = 0.8854 f1 = 0.9479
	test loss = 2.7636 acc = 0.5000 bacc = 0.5000 f1 = 0.6667


Train: 100%|██████████| 143/143 [01:20<00:00,  1.78it/s]
Test: 100%|██████████| 20/20 [00:09<00:00,  2.19it/s]


[Epoch 36 / 100]
	train loss = 0.4402 acc = 0.9650 bacc = 0.9528 f1 = 0.9758
	test loss = 2.7438 acc = 0.5000 bacc = 0.5000 f1 = 0.6667


Train: 100%|██████████| 143/143 [01:19<00:00,  1.79it/s]
Test: 100%|██████████| 20/20 [00:09<00:00,  2.21it/s]


[Epoch 37 / 100]
	train loss = 0.4479 acc = 0.9930 bacc = 0.9875 f1 = 0.9952
	test loss = 2.1628 acc = 0.5000 bacc = 0.5000 f1 = 0.6667


Train: 100%|██████████| 143/143 [01:20<00:00,  1.79it/s]
Test: 100%|██████████| 20/20 [00:09<00:00,  2.17it/s]


[Epoch 38 / 100]
	train loss = 0.4361 acc = 0.9441 bacc = 0.9153 f1 = 0.9619
	test loss = 2.6210 acc = 0.5000 bacc = 0.5000 f1 = 0.6667


Train: 100%|██████████| 143/143 [01:21<00:00,  1.76it/s]
Test: 100%|██████████| 20/20 [00:09<00:00,  2.20it/s]


[Epoch 39 / 100]
	train loss = 0.4642 acc = 0.9650 bacc = 0.9528 f1 = 0.9758
	test loss = 2.3119 acc = 0.5000 bacc = 0.5000 f1 = 0.6667


Train: 100%|██████████| 143/143 [01:20<00:00,  1.77it/s]
Test: 100%|██████████| 20/20 [00:09<00:00,  2.17it/s]


[Epoch 40 / 100]
	train loss = 0.4606 acc = 0.9720 bacc = 0.9576 f1 = 0.9808
	test loss = 3.4343 acc = 0.5000 bacc = 0.5000 f1 = 0.0000


Train: 100%|██████████| 143/143 [01:20<00:00,  1.78it/s]
Test: 100%|██████████| 20/20 [00:09<00:00,  2.19it/s]


[Epoch 41 / 100]
	train loss = 0.5366 acc = 0.9441 bacc = 0.9306 f1 = 0.9612
	test loss = 2.6941 acc = 0.5000 bacc = 0.5000 f1 = 0.0000


Train: 100%|██████████| 143/143 [01:20<00:00,  1.78it/s]
Test: 100%|██████████| 20/20 [00:09<00:00,  2.17it/s]


[Epoch 42 / 100]
	train loss = 0.4356 acc = 0.9441 bacc = 0.9382 f1 = 0.9608
	test loss = 2.7094 acc = 0.5000 bacc = 0.5000 f1 = 0.0000


Train: 100%|██████████| 143/143 [01:21<00:00,  1.76it/s]
Test: 100%|██████████| 20/20 [00:09<00:00,  2.19it/s]


[Epoch 43 / 100]
	train loss = 0.4798 acc = 0.9441 bacc = 0.9382 f1 = 0.9608
	test loss = 2.8359 acc = 0.5000 bacc = 0.5000 f1 = 0.6667


Train: 100%|██████████| 143/143 [01:21<00:00,  1.76it/s]
Test: 100%|██████████| 20/20 [00:09<00:00,  2.15it/s]


[Epoch 44 / 100]
	train loss = 0.4409 acc = 0.9860 bacc = 0.9826 f1 = 0.9903
	test loss = 2.7120 acc = 0.5000 bacc = 0.5000 f1 = 0.0000


Train: 100%|██████████| 143/143 [01:21<00:00,  1.75it/s]
Test: 100%|██████████| 20/20 [00:09<00:00,  2.21it/s]


[Epoch 45 / 100]
	train loss = 0.3862 acc = 0.9930 bacc = 0.9951 f1 = 0.9951
	test loss = 2.4350 acc = 0.5000 bacc = 0.5000 f1 = 0.6667


Train: 100%|██████████| 143/143 [01:22<00:00,  1.73it/s]
Test: 100%|██████████| 20/20 [00:09<00:00,  2.17it/s]


[Epoch 46 / 100]
	train loss = 0.5097 acc = 0.9790 bacc = 0.9778 f1 = 0.9854
	test loss = 2.2007 acc = 0.5000 bacc = 0.5000 f1 = 0.6667


Train: 100%|██████████| 143/143 [01:20<00:00,  1.77it/s]
Test: 100%|██████████| 20/20 [00:09<00:00,  2.21it/s]


[Epoch 47 / 100]
	train loss = 0.4594 acc = 0.9720 bacc = 0.9729 f1 = 0.9804
	test loss = 2.4715 acc = 0.5000 bacc = 0.5000 f1 = 0.6667


Train: 100%|██████████| 143/143 [01:21<00:00,  1.76it/s]
Test: 100%|██████████| 20/20 [00:09<00:00,  2.21it/s]


[Epoch 48 / 100]
	train loss = 0.4998 acc = 0.9650 bacc = 0.9681 f1 = 0.9754
	test loss = 2.5620 acc = 0.5000 bacc = 0.5000 f1 = 0.6667


Train: 100%|██████████| 143/143 [01:20<00:00,  1.77it/s]
Test: 100%|██████████| 20/20 [00:09<00:00,  2.18it/s]


[Epoch 49 / 100]
	train loss = 0.3262 acc = 0.9790 bacc = 0.9701 f1 = 0.9855
	test loss = 3.7152 acc = 0.5000 bacc = 0.5000 f1 = 0.0000


Train: 100%|██████████| 143/143 [01:20<00:00,  1.77it/s]
Test: 100%|██████████| 20/20 [00:09<00:00,  2.18it/s]


[Epoch 50 / 100]
	train loss = 0.4684 acc = 0.9510 bacc = 0.9584 f1 = 0.9652
	test loss = 2.3259 acc = 0.5000 bacc = 0.5000 f1 = 0.6667


Train: 100%|██████████| 143/143 [01:21<00:00,  1.75it/s]
Test: 100%|██████████| 20/20 [00:09<00:00,  2.17it/s]


[Epoch 51 / 100]
	train loss = 0.4467 acc = 0.9510 bacc = 0.9278 f1 = 0.9665
	test loss = 4.0801 acc = 0.5000 bacc = 0.5000 f1 = 0.0000


Train: 100%|██████████| 143/143 [01:21<00:00,  1.75it/s]
Test: 100%|██████████| 20/20 [00:09<00:00,  2.14it/s]


[Epoch 52 / 100]
	train loss = 0.4747 acc = 0.9580 bacc = 0.9556 f1 = 0.9706
	test loss = 2.4761 acc = 0.5000 bacc = 0.5000 f1 = 0.6667


Train: 100%|██████████| 143/143 [01:21<00:00,  1.76it/s]
Test: 100%|██████████| 20/20 [00:09<00:00,  2.19it/s]


[Epoch 53 / 100]
	train loss = 0.4260 acc = 0.9720 bacc = 0.9653 f1 = 0.9806
	test loss = 3.7109 acc = 0.5000 bacc = 0.5000 f1 = 0.0000


Train: 100%|██████████| 143/143 [01:21<00:00,  1.76it/s]
Test: 100%|██████████| 20/20 [00:09<00:00,  2.19it/s]


[Epoch 54 / 100]
	train loss = 0.4652 acc = 0.9510 bacc = 0.9507 f1 = 0.9655
	test loss = 3.9080 acc = 0.5000 bacc = 0.5000 f1 = 0.0000


Train: 100%|██████████| 143/143 [01:20<00:00,  1.77it/s]
Test: 100%|██████████| 20/20 [00:09<00:00,  2.14it/s]


[Epoch 55 / 100]
	train loss = 0.4819 acc = 0.9510 bacc = 0.9354 f1 = 0.9662
	test loss = 2.6332 acc = 0.5000 bacc = 0.5000 f1 = 0.6667


Train: 100%|██████████| 143/143 [01:23<00:00,  1.71it/s]
Test: 100%|██████████| 20/20 [00:09<00:00,  2.15it/s]


[Epoch 56 / 100]
	train loss = 0.4851 acc = 0.9510 bacc = 0.9278 f1 = 0.9665
	test loss = 2.2549 acc = 0.5000 bacc = 0.5000 f1 = 0.6667


Train: 100%|██████████| 143/143 [01:22<00:00,  1.73it/s]
Test: 100%|██████████| 20/20 [00:09<00:00,  2.12it/s]


[Epoch 57 / 100]
	train loss = 0.4106 acc = 0.9720 bacc = 0.9576 f1 = 0.9808
	test loss = 2.9511 acc = 0.5000 bacc = 0.5000 f1 = 0.6667


Train: 100%|██████████| 143/143 [01:21<00:00,  1.74it/s]
Test: 100%|██████████| 20/20 [00:09<00:00,  2.21it/s]


[Epoch 58 / 100]
	train loss = 0.5534 acc = 0.9301 bacc = 0.9132 f1 = 0.9515
	test loss = 2.4405 acc = 0.5000 bacc = 0.5000 f1 = 0.6667


Train: 100%|██████████| 143/143 [01:22<00:00,  1.72it/s]
Test: 100%|██████████| 20/20 [00:09<00:00,  2.12it/s]


[Epoch 59 / 100]
	train loss = 0.4434 acc = 0.9650 bacc = 0.9451 f1 = 0.9761
	test loss = 2.4264 acc = 0.5000 bacc = 0.5000 f1 = 0.6667


Train: 100%|██████████| 143/143 [01:22<00:00,  1.73it/s]
Test: 100%|██████████| 20/20 [00:09<00:00,  2.17it/s]


[Epoch 60 / 100]
	train loss = 0.4842 acc = 0.9650 bacc = 0.9528 f1 = 0.9758
	test loss = 2.6502 acc = 0.5000 bacc = 0.5000 f1 = 0.6667


Train: 100%|██████████| 143/143 [01:22<00:00,  1.74it/s]
Test: 100%|██████████| 20/20 [00:09<00:00,  2.09it/s]


[Epoch 61 / 100]
	train loss = 0.5118 acc = 0.9580 bacc = 0.9403 f1 = 0.9712
	test loss = 2.9070 acc = 0.5000 bacc = 0.5000 f1 = 0.0000


Train: 100%|██████████| 143/143 [01:22<00:00,  1.74it/s]
Test: 100%|██████████| 20/20 [00:09<00:00,  2.18it/s]


[Epoch 62 / 100]
	train loss = 0.5479 acc = 0.9301 bacc = 0.9362 f1 = 0.9500
	test loss = 2.3963 acc = 0.5000 bacc = 0.5000 f1 = 0.6667


Train: 100%|██████████| 143/143 [01:23<00:00,  1.72it/s]
Test: 100%|██████████| 20/20 [00:09<00:00,  2.10it/s]


[Epoch 63 / 100]
	train loss = 0.4933 acc = 0.9790 bacc = 0.9701 f1 = 0.9855
	test loss = 2.1397 acc = 0.5000 bacc = 0.5000 f1 = 0.6667


Train: 100%|██████████| 143/143 [01:21<00:00,  1.75it/s]
Test: 100%|██████████| 20/20 [00:09<00:00,  2.15it/s]


[Epoch 64 / 100]
	train loss = 0.4696 acc = 0.9301 bacc = 0.8979 f1 = 0.9524
	test loss = 2.3949 acc = 0.5000 bacc = 0.5000 f1 = 0.6667


Train: 100%|██████████| 143/143 [01:21<00:00,  1.75it/s]
Test: 100%|██████████| 20/20 [00:09<00:00,  2.21it/s]


[Epoch 65 / 100]
	train loss = 0.4591 acc = 0.9650 bacc = 0.9451 f1 = 0.9761
	test loss = 2.6324 acc = 0.5000 bacc = 0.5000 f1 = 0.6667


Train: 100%|██████████| 143/143 [01:20<00:00,  1.77it/s]
Test: 100%|██████████| 20/20 [00:09<00:00,  2.18it/s]


[Epoch 66 / 100]
	train loss = 0.3741 acc = 0.9650 bacc = 0.9528 f1 = 0.9758
	test loss = 3.4820 acc = 0.5000 bacc = 0.5000 f1 = 0.0000


Train: 100%|██████████| 143/143 [01:20<00:00,  1.78it/s]
Test: 100%|██████████| 20/20 [00:09<00:00,  2.16it/s]


[Epoch 67 / 100]
	train loss = 0.4313 acc = 0.9580 bacc = 0.9632 f1 = 0.9703
	test loss = 2.7955 acc = 0.5000 bacc = 0.5000 f1 = 0.6667


Train: 100%|██████████| 143/143 [01:21<00:00,  1.75it/s]
Test: 100%|██████████| 20/20 [00:09<00:00,  2.19it/s]


[Epoch 68 / 100]
	train loss = 0.4348 acc = 0.9790 bacc = 0.9625 f1 = 0.9856
	test loss = 2.3802 acc = 0.5000 bacc = 0.5000 f1 = 0.6667


Train: 100%|██████████| 143/143 [01:21<00:00,  1.76it/s]
Test: 100%|██████████| 20/20 [00:09<00:00,  1.72it/s]

In [None]:
all_ids, all_preds = predict(model, image_encoder, test_dataloader, seq_len, device)
submission_df = pd.DataFrame({"ID": all_ids, "Predicted": all_preds})
submission_df.to_csv(os.path.join(save_dir, "submission_mean_01.csv"), index=False)

In [None]:
trainer(
    model,
    image_encoder,
    optimizer,
    criterion,
    val_dataloader,
    val_dataloader,
    n_epochs,
    seq_len,
    device,
    "mean_model"
)

In [None]:
all_ids, all_preds = predict(model, test_dataloader, seq_len, device)
submission_df = pd.DataFrame({"ID": all_ids, "Predicted": all_preds})
submission_df.to_csv(os.path.join(save_dir, "submission_mean_02.csv"), index=False)

### Max

In [None]:
lr = 1e-2 #1e-3
n_epochs = 100

image_encoder = VGG16Encoder() #Resnet50Encoder()
d_model = 25_088 #2048

model = ImageAggregator(
    d_model=d_model,
    aggregator=lambda x: torch.max(x, dim=-2).values
)
optimizer = optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss(reduction="sum")
transform_function = lambda x: x.reshape(3, x.shape[0], x.shape[1])

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
image_encoder.to(device)

print(device)

In [None]:
trainer(
    model,
    image_encoder,
    optimizer,
    criterion,
    train_dataloader,
    val_dataloader,
    n_epochs,
    seq_len,
    device,
    "max_model"
)

In [None]:
all_ids, all_preds = predict(model, image_encoder, test_dataloader, seq_len, device)
submission_df = pd.DataFrame({"ID": all_ids, "Predicted": all_preds})
submission_df.to_csv(os.path.join(save_dir, "submission_max_01.csv"), index=False)

In [None]:
trainer(
    model,
    image_encoder,
    optimizer,
    criterion,
    val_dataloader,
    val_dataloader,
    n_epochs,
    seq_len,
    device,
    "max_model"
)

In [None]:
all_ids, all_preds = predict(model, image_encoder, test_dataloader, seq_len, device)
submission_df = pd.DataFrame({"ID": all_ids, "Predicted": all_preds})
submission_df.to_csv(os.path.join(save_dir, "submission_max_02.csv"), index=False)