In [41]:
import enum
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torch.optim as optim
from PIL import Image
from sklearn import metrics
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
#from utils_pretrains import *

In [42]:
import enum
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.models as models
from PIL import Image
from sklearn import metrics
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm


class Resnet18Encoder(nn.Module):

    def __init__(self):
        super().__init__()
        self.encoder = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        self.encoder = nn.Sequential(*list(self.encoder.children())[:-1])

    def forward(self, x):
        x = self.encoder(x)
        x = x.view(x.size(0), x.size(1))
        return x


class Resnet50Encoder(nn.Module):

    def __init__(self):
        super().__init__()
        self.encoder = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
        self.encoder = nn.Sequential(*list(self.encoder.children())[:-1])

    def forward(self, x):
        x = self.encoder(x)
        x = x.view(x.size(0), x.size(1))
        return x


class VGG16Encoder(nn.Module):

    def __init__(self):
        super().__init__()
        self.encoder = models.vgg16()
        self.encoder = nn.Sequential(*list(self.encoder.children())[:-1])

    def forward(self, x):
        x = self.encoder(x)
        x = x.view(x.size(0), 512 * 7 * 7)
        return x


class BagImageModel(nn.Module):

    def __init__(
        self,
        d_model=2048,
        aggregator=lambda x: torch.mean(x, dim=-2),
        num_classes=2
    ):
        super().__init__()
        self.d_model = d_model
        self.aggregator = aggregator
        self.batch_norm = nn.BatchNorm1d(d_model)
        self.head = nn.Sequential(
            nn.Linear(d_model, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, num_classes)
        )

    def forward(self, x: torch.Tensor):
        batch_size, seq_len, d_model = x.size()
        if seq_len > 1:
            x = self.batch_norm(x.view(seq_len, d_model, batch_size))
        x = x.view(batch_size, seq_len, d_model)
        x = self.aggregator(x)
        x = self.head(x)
        return x


class PredictionAggregator(enum.Enum):
    MIN  = 0
    MAX  = 1
    MEAN = 2


class PatientImageDataset(Dataset):

    def __init__(self, 
        patient_path, 
        transform_function=None,
    ):
        super().__init__()
        self.patient_path = patient_path
        self.transform_function = transform_function

        images_path = []
        for image_path in os.listdir(patient_path):
            images_path.append(os.path.join(patient_path, image_path))

        self.images_path = images_path

    def __len__(self):
        return len(self.images_path)

    def __getitem__(self, index):
        image_path = self.images_path[index]
        image = image_to_numpy(image_path)
        if self.transform_function is not None:
            image = self.transform_function(image)
        return image


class PatientsDataset(Dataset):

    def __init__(
        self, 
        split_dir, 
        patients_ids,
        patients_labels
    ):
        super().__init__()
        self.split_dir = split_dir
        self.patients_ids = patients_ids
        self.patients_labels = patients_labels

    def __len__(self):
        return len(self.patients_ids)

    def __getitem__(self, index):
        patient_id = self.patients_ids[index]
        patient_label = self.patients_labels[index]
        patient_path = os.path.join(self.split_dir, patient_id)
        return {
            "id": patient_id, 
            "label": patient_label, 
            "patient_path": patient_path
        }


def image_to_numpy(image_path):
    image_pil = Image.open(image_path)
    image_np = np.array(image_pil, dtype=np.float32) / 255.0
    return image_np


def load_images(split_dir):
    images = {}

    for p in os.listdir(split_dir):
        p_dir = os.path.join(split_dir, p)
        if not os.path.isdir(p_dir):
            continue

        p_images = []
        for image_path in os.listdir(p_dir):
            p_images.append(os.path.join(p_dir, image_path))

        images[p] = p_images
    return images


def assign_labels(images_dict, labels_dict):
    images = []
    labels = []
    for p, p_images in images_dict.items():
        images.extend(p_images)
        labels.extend([labels_dict[p]] * len(p_images))
    return images, labels


def predict(
    model,
    image_encoder,
    dataloader,
    seq_len,
    transform_function,
    prediction_aggregator,
    device
):
    all_ids = []
    all_preds = []

    with torch.no_grad():
        for data in tqdm(dataloader, "Predict"):

            patient_id = data["id"][0]
            batch_preds = []

            patient_path = data["patient_path"][0]
            patient_dataset = PatientImageDataset(
                patient_path=patient_path,
                transform_function=transform_function,
            )
            patient_dataloader = DataLoader(patient_dataset, batch_size=seq_len, shuffle=False)

            for batch_images in patient_dataloader:
                batch_images = batch_images.unsqueeze(0).to(device)
                
                batch_size, seq_len, im_chan, im_width, im_height = batch_images.size()
                x = batch_images.view(batch_size * seq_len, im_chan, im_width, im_height)
                x = image_encoder(x)
                x = x.view(batch_size, seq_len, model.d_model)

                logits = model(x)
                preds = logits.argmax(dim=-1).cpu().detach().numpy()
                batch_preds.extend(preds)

            if prediction_aggregator is PredictionAggregator.MIN:
                patient_pred = np.min(batch_preds)
            elif prediction_aggregator is PredictionAggregator.MAX:
                patient_pred = np.max(batch_preds)
            else:
                patient_pred = int(np.mean(batch_preds) > 0.5)

            all_preds.append(patient_pred)
            all_ids.append(patient_id)

    return all_ids, all_preds


def test(
    model,
    image_encoder,
    criterion,
    dataloader,
    seq_len,
    transform_function,
    prediction_aggregator,
    device
):
    total_loss = 0
    all_preds = []
    all_true = []

    with torch.no_grad():
        for data in tqdm(dataloader, "Test"):

            patient_label_item = data["label"].item()
            patient_label = data["label"].to(device)
            batch_preds = []
            patient_loss = 0

            patient_path = data["patient_path"][0]
            patient_dataset = PatientImageDataset(
                patient_path=patient_path,
                transform_function=transform_function,
            )
            patient_dataloader = DataLoader(patient_dataset, batch_size=seq_len, shuffle=False)

            for batch_images in patient_dataloader:
                batch_images = batch_images.unsqueeze(0).to(device)

                batch_size, seq_len, im_chan, im_width, im_height = batch_images.size()
                x = batch_images.view(batch_size * seq_len, im_chan, im_width, im_height)
                x = image_encoder(x)
                x = x.view(batch_size, seq_len, model.d_model)

                logits = model(x)
                loss = criterion(logits, patient_label)

                patient_loss += loss.item()
                preds = logits.argmax(dim=-1).cpu().detach().numpy()
                batch_preds.extend(preds)

            patient_loss /= len(patient_dataset)
            total_loss += patient_loss
            
            if prediction_aggregator is PredictionAggregator.MIN:
                patient_pred = np.min(batch_preds)
            elif prediction_aggregator is PredictionAggregator.MAX:
                patient_pred = np.max(batch_preds)
            else:
                patient_pred = int(np.mean(batch_preds) > 0.5)

            all_preds.append(patient_pred)
            all_true.append(patient_label_item)

    total_loss /= len(dataloader.dataset)
    acc = metrics.accuracy_score(all_true, all_preds)
    bacc = metrics.balanced_accuracy_score(all_true, all_preds)

    return total_loss, acc, bacc


def train(
    model,
    image_encoder,
    optimizer,
    criterion,
    dataloader,
    seq_len,
    transform_function,
    prediction_aggregator,
    device
):
    total_loss = 0
    all_preds = []
    all_true = []

    for data in tqdm(dataloader, "Train"):

        patient_label_item = data["label"].item()
        patient_label = data["label"].to(device)
        batch_preds = []
        patient_loss = 0

        patient_path = data["patient_path"][0]
        patient_dataset = PatientImageDataset(
            patient_path=patient_path,
            transform_function=transform_function,
        )
        patient_dataloader = DataLoader(patient_dataset, batch_size=seq_len, shuffle=True)

        for batch_images in patient_dataloader:
            batch_images = batch_images.unsqueeze(0).to(device)
            
            with torch.no_grad():
                batch_size, seq_len, im_chan, im_width, im_height = batch_images.size()
                x = batch_images.view(batch_size * seq_len, im_chan, im_width, im_height)
                x = image_encoder(x)
                x = x.view(batch_size, seq_len, model.d_model)

            logits = model(x)
            loss = criterion(logits, patient_label)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            patient_loss += loss.item()
            preds = logits.argmax(dim=-1).cpu().detach().numpy()
            batch_preds.extend(preds)

        patient_loss /= len(patient_dataset)
        total_loss += patient_loss

        if prediction_aggregator is PredictionAggregator.MIN:
            patient_pred = np.min(batch_preds)
        elif prediction_aggregator is PredictionAggregator.MAX:
            patient_pred = np.max(batch_preds)
        else:
            patient_pred = int(np.mean(batch_preds) > 0.5)

        all_preds.append(patient_pred)
        all_true.append(patient_label_item)

    total_loss /= len(dataloader.dataset)
    acc = metrics.accuracy_score(all_true, all_preds)
    bacc = metrics.balanced_accuracy_score(all_true, all_preds)

    return total_loss, acc, bacc


def trainer(
    model, 
    image_encoder,
    optimizer, 
    criterion, 
    train_dataloader, 
    test_dataloader, 
    n_epochs,
    seq_len,
    transform_function,
    prediction_aggregator,
    device,
    model_name,
    verbose=True,
    verbose_every=1,
    save=True,
    save_every=1,
    save_dir=".",
    plot=True
):
    train_accs      = []
    train_baccs     = []
    train_losses    = []
    test_accs       = []
    test_baccs      = []
    test_losses     = []

    for epoch in range(1, n_epochs + 1):
        train_loss, train_acc, train_bacc = train(
            model, image_encoder, optimizer, criterion, train_dataloader,
            seq_len, transform_function, prediction_aggregator, device
        )
        test_loss, test_acc, test_bacc = test(
            model, image_encoder, criterion, test_dataloader,
            seq_len, transform_function, prediction_aggregator, device
        )

        train_accs.append(train_acc)
        train_baccs.append(train_bacc)
        train_losses.append(train_loss)

        test_accs.append(test_acc)
        test_baccs.append(test_bacc)
        test_losses.append(test_loss)

        if verbose and epoch % verbose_every == 0:
            print(
                f"[Epoch {epoch} / {n_epochs}]",
                f"\ttrain loss = {train_loss:.4f} acc = {train_acc:.4f} bacc = {train_bacc:.4f}",
                f"\ttest loss = {test_loss:.4f} acc = {test_acc:.4f} bacc = {test_bacc:.4f}", 
                sep="\n"
            )

        if save and epoch % save_every == 0:
            torch.save(
                {
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                }, 
                os.path.join(save_dir, f"{model_name}.pt")
            )

    if plot:
        plt.figure()
        plt.plot(train_accs, label="Train")
        plt.plot(test_accs, label="Test")
        plt.title("Accuracy")
        plt.show()

        plt.figure()
        plt.plot(train_baccs, label="Train")
        plt.plot(test_baccs, label="Test")
        plt.title("Balanced accuracy")
        plt.show()

        plt.figure()
        plt.plot(train_losses, label="Train")
        plt.plot(test_losses, label="Test")
        plt.title("Loss")
        plt.show()
    

def load_checkpoint(checkpoint_path, model_args={}, optimizer_args={}):
    model = BagImageModel(**model_args)
    optimizer = optim.Adam(model.parameters(), **optimizer_args)

    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    
    return model, optimizer, epoch

In [43]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [44]:
data_dir = "/kaggle/input/dlmi-lymphocytosis/dlmi-lymphocytosis-classification/"
#data_dir = "../../data/dlmi-lymphocytosis-classification/"
train_dir = data_dir + "trainset/"
test_dir = data_dir + "testset/"
save_dir = "/kaggle/working/"

In [45]:
trainset_true_df = pd.read_csv(data_dir + "trainset/trainset_true.csv")
testset_data_df = pd.read_csv(data_dir + "testset/testset_data.csv")

In [46]:
id_train_val = trainset_true_df["ID"].to_numpy()
labels = trainset_true_df["LABEL"].to_numpy()
labels_dict = dict(zip(id_train_val, labels))

id_train_val_0 = id_train_val[labels == 0]
id_train_val_1 = id_train_val[labels == 1]

n_id_val_per_class = 10
id_val_0 = np.random.choice(id_train_val_0, size=n_id_val_per_class, replace=False)
id_val_1 = np.random.choice(id_train_val_1, size=n_id_val_per_class, replace=False)

id_val = np.concatenate([id_val_0, id_val_1])
mask = np.isin(id_train_val, id_val, invert=True)
id_train = id_train_val[mask]

id_train.shape, id_val.shape

((143,), (20,))

In [47]:
batch_size = 1
seq_len = 128

In [48]:
train_dataset = PatientsDataset(
    split_dir=train_dir,
    patients_ids=id_train,
    patients_labels=np.array([labels_dict[id] for id in id_train]),
)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

val_dataset = PatientsDataset(
    split_dir=train_dir,
    patients_ids=id_val,
    patients_labels=np.array([labels_dict[id] for id in id_val]),
)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

train_val_dataset = PatientsDataset(
    split_dir=train_dir,
    patients_ids=id_train_val,
    patients_labels=labels,
)
train_val_dataloader = DataLoader(train_val_dataset, batch_size=batch_size, shuffle=False)

test_dataset = PatientsDataset(
    split_dir=test_dir,
    patients_ids=testset_data_df["ID"].to_numpy(),
    patients_labels=[-1] * len(trainset_true_df["ID"]),
)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

### Mean

In [49]:
lr = 1e-2
n_epochs = 100

image_encoder = Resnet18Encoder()
image_encoder.to(device)
d_model = 512
model = BagImageModel(
    d_model=512,
    aggregator=lambda x: torch.mean(x, dim=-2)
)
model.to(device)
model_name = "ResNet18_mean"

optimizer = optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss(reduction="sum")
transform_function = lambda x: x.reshape(3, x.shape[0], x.shape[1])
prediction_aggregator = PredictionAggregator.MIN

In [None]:
trainer(
    model, 
    image_encoder,
    optimizer, 
    criterion, 
    train_dataloader, 
    val_dataloader, 
    n_epochs,
    seq_len,
    transform_function,
    prediction_aggregator,
    device,
    model_name,
    verbose=True,
    verbose_every=1,
    save=True,
    save_every=1,
    save_dir=save_dir,
    plot=True
)

Train: 100%|██████████| 143/143 [01:16<00:00,  1.87it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.43it/s]


[Epoch 1 / 100]
	train loss = 0.7617 acc = 0.7552 bacc = 0.6848
	test loss = 0.9657 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:15<00:00,  1.88it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.33it/s]


[Epoch 2 / 100]
	train loss = 0.5090 acc = 0.7343 bacc = 0.6091
	test loss = 1.1696 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:17<00:00,  1.84it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.36it/s]


[Epoch 3 / 100]
	train loss = 2.4526 acc = 0.7133 bacc = 0.6022
	test loss = 0.6466 acc = 0.6500 bacc = 0.6500


Train: 100%|██████████| 143/143 [01:17<00:00,  1.84it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.29it/s]


[Epoch 4 / 100]
	train loss = 1.2189 acc = 0.7203 bacc = 0.6223
	test loss = 0.5904 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:18<00:00,  1.82it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.47it/s]


[Epoch 5 / 100]
	train loss = 0.5839 acc = 0.6923 bacc = 0.5723
	test loss = 1.1601 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:17<00:00,  1.84it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.32it/s]


[Epoch 6 / 100]
	train loss = 0.5895 acc = 0.7552 bacc = 0.6466
	test loss = 0.5891 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:17<00:00,  1.85it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.47it/s]


[Epoch 7 / 100]
	train loss = 0.5635 acc = 0.7273 bacc = 0.6501
	test loss = 0.6330 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:17<00:00,  1.84it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.43it/s]


[Epoch 8 / 100]
	train loss = 0.5762 acc = 0.7343 bacc = 0.6473
	test loss = 0.7775 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:18<00:00,  1.82it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.49it/s]


[Epoch 9 / 100]
	train loss = 0.5185 acc = 0.7343 bacc = 0.6167
	test loss = 0.6512 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:17<00:00,  1.84it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.30it/s]


[Epoch 10 / 100]
	train loss = 0.5650 acc = 0.7483 bacc = 0.6417
	test loss = 1.0085 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:16<00:00,  1.87it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.40it/s]


[Epoch 11 / 100]
	train loss = 0.5404 acc = 0.7483 bacc = 0.6570
	test loss = 0.9940 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:16<00:00,  1.86it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.48it/s]


[Epoch 12 / 100]
	train loss = 0.5373 acc = 0.7552 bacc = 0.6695
	test loss = 1.1860 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:18<00:00,  1.82it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.34it/s]


[Epoch 13 / 100]
	train loss = 0.5929 acc = 0.7203 bacc = 0.5994
	test loss = 1.5236 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:19<00:00,  1.80it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.43it/s]


[Epoch 14 / 100]
	train loss = 0.5657 acc = 0.7483 bacc = 0.6417
	test loss = 0.8896 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:16<00:00,  1.86it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.40it/s]


[Epoch 15 / 100]
	train loss = 0.5180 acc = 0.7622 bacc = 0.6515
	test loss = 0.6958 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:17<00:00,  1.84it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.47it/s]


[Epoch 16 / 100]
	train loss = 0.5744 acc = 0.7413 bacc = 0.6369
	test loss = 1.2759 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:17<00:00,  1.85it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.40it/s]


[Epoch 17 / 100]
	train loss = 0.5882 acc = 0.7203 bacc = 0.6223
	test loss = 1.3172 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:17<00:00,  1.84it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.33it/s]


[Epoch 18 / 100]
	train loss = 0.5328 acc = 0.7133 bacc = 0.5716
	test loss = 1.3433 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:16<00:00,  1.86it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.45it/s]


[Epoch 19 / 100]
	train loss = 0.6041 acc = 0.7203 bacc = 0.5917
	test loss = 0.6154 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:16<00:00,  1.87it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.36it/s]


[Epoch 20 / 100]
	train loss = 0.5298 acc = 0.7273 bacc = 0.6731
	test loss = 1.2693 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:16<00:00,  1.86it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.43it/s]


[Epoch 21 / 100]
	train loss = 0.5186 acc = 0.7902 bacc = 0.6556
	test loss = 0.6286 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:17<00:00,  1.85it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.41it/s]


[Epoch 22 / 100]
	train loss = 0.5657 acc = 0.7273 bacc = 0.6195
	test loss = 0.7099 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:17<00:00,  1.84it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.42it/s]


[Epoch 23 / 100]
	train loss = 0.5649 acc = 0.7273 bacc = 0.6195
	test loss = 0.7310 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:15<00:00,  1.90it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.45it/s]


[Epoch 24 / 100]
	train loss = 0.5326 acc = 0.7483 bacc = 0.6341
	test loss = 1.0378 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:15<00:00,  1.88it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.35it/s]


[Epoch 25 / 100]
	train loss = 0.5437 acc = 0.7203 bacc = 0.5994
	test loss = 0.8318 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:18<00:00,  1.82it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.45it/s]


[Epoch 26 / 100]
	train loss = 0.5313 acc = 0.7552 bacc = 0.6619
	test loss = 1.4627 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:15<00:00,  1.89it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.42it/s]


[Epoch 27 / 100]
	train loss = 0.5891 acc = 0.7133 bacc = 0.5792
	test loss = 1.2858 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:14<00:00,  1.91it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.42it/s]


[Epoch 28 / 100]
	train loss = 0.5516 acc = 0.7483 bacc = 0.6800
	test loss = 0.8567 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:16<00:00,  1.88it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.38it/s]


[Epoch 29 / 100]
	train loss = 0.5894 acc = 0.7203 bacc = 0.5917
	test loss = 0.6835 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:17<00:00,  1.84it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.40it/s]


[Epoch 30 / 100]
	train loss = 0.5758 acc = 0.7413 bacc = 0.6369
	test loss = 0.9373 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:19<00:00,  1.80it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.40it/s]


[Epoch 31 / 100]
	train loss = 0.5419 acc = 0.7622 bacc = 0.6667
	test loss = 0.5963 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:22<00:00,  1.73it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.39it/s]


[Epoch 32 / 100]
	train loss = 0.5423 acc = 0.7063 bacc = 0.6203
	test loss = 0.6031 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:15<00:00,  1.90it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.44it/s]


[Epoch 33 / 100]
	train loss = 0.5779 acc = 0.7133 bacc = 0.5945
	test loss = 1.0075 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:17<00:00,  1.85it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.43it/s]


[Epoch 34 / 100]
	train loss = 0.5655 acc = 0.7133 bacc = 0.5945
	test loss = 1.1282 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:17<00:00,  1.85it/s]
Test: 100%|██████████| 20/20 [00:07<00:00,  2.50it/s]


[Epoch 35 / 100]
	train loss = 0.6019 acc = 0.7133 bacc = 0.5869
	test loss = 0.8866 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:14<00:00,  1.93it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.40it/s]


[Epoch 36 / 100]
	train loss = 0.5575 acc = 0.7343 bacc = 0.6091
	test loss = 0.8315 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:14<00:00,  1.92it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.47it/s]


[Epoch 37 / 100]
	train loss = 0.5407 acc = 0.7413 bacc = 0.6675
	test loss = 1.0032 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:15<00:00,  1.88it/s]
Test: 100%|██████████| 20/20 [00:07<00:00,  2.53it/s]


[Epoch 38 / 100]
	train loss = 0.5901 acc = 0.7273 bacc = 0.5813
	test loss = 0.6196 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:15<00:00,  1.90it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.38it/s]


[Epoch 39 / 100]
	train loss = 0.5492 acc = 0.7552 bacc = 0.6466
	test loss = 1.1057 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:15<00:00,  1.89it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.48it/s]


[Epoch 40 / 100]
	train loss = 0.5461 acc = 0.7622 bacc = 0.6820
	test loss = 1.0733 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:14<00:00,  1.91it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.45it/s]


[Epoch 41 / 100]
	train loss = 0.5391 acc = 0.7413 bacc = 0.6522
	test loss = 1.2964 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:13<00:00,  1.95it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.46it/s]


[Epoch 42 / 100]
	train loss = 0.4951 acc = 0.7413 bacc = 0.6369
	test loss = 1.0632 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:16<00:00,  1.86it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.49it/s]


[Epoch 43 / 100]
	train loss = 0.5426 acc = 0.7622 bacc = 0.6591
	test loss = 0.9136 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:16<00:00,  1.88it/s]
Test: 100%|██████████| 20/20 [00:07<00:00,  2.51it/s]


[Epoch 44 / 100]
	train loss = 0.5089 acc = 0.7622 bacc = 0.6897
	test loss = 0.6444 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:15<00:00,  1.89it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.46it/s]


[Epoch 45 / 100]
	train loss = 0.5338 acc = 0.7552 bacc = 0.6313
	test loss = 0.7411 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:15<00:00,  1.90it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.48it/s]


[Epoch 46 / 100]
	train loss = 0.5487 acc = 0.7483 bacc = 0.6417
	test loss = 0.8059 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:14<00:00,  1.92it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.44it/s]


[Epoch 47 / 100]
	train loss = 0.5292 acc = 0.6993 bacc = 0.6231
	test loss = 0.5910 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:16<00:00,  1.88it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.45it/s]


[Epoch 48 / 100]
	train loss = 0.5332 acc = 0.7413 bacc = 0.6598
	test loss = 1.8271 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:16<00:00,  1.88it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.47it/s]


[Epoch 49 / 100]
	train loss = 0.5772 acc = 0.7343 bacc = 0.6167
	test loss = 0.8185 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:14<00:00,  1.91it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.50it/s]


[Epoch 50 / 100]
	train loss = 0.5662 acc = 0.7273 bacc = 0.6195
	test loss = 0.6429 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:13<00:00,  1.94it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.45it/s]


[Epoch 51 / 100]
	train loss = 0.4920 acc = 0.7762 bacc = 0.6841
	test loss = 0.9126 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:15<00:00,  1.90it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.45it/s]


[Epoch 52 / 100]
	train loss = 0.5919 acc = 0.6853 bacc = 0.5598
	test loss = 0.8863 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:16<00:00,  1.88it/s]
Test: 100%|██████████| 20/20 [00:07<00:00,  2.52it/s]


[Epoch 53 / 100]
	train loss = 0.5648 acc = 0.7273 bacc = 0.5966
	test loss = 0.6314 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:14<00:00,  1.91it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.45it/s]


[Epoch 54 / 100]
	train loss = 0.5620 acc = 0.7063 bacc = 0.6050
	test loss = 0.7656 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:16<00:00,  1.86it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.46it/s]


[Epoch 55 / 100]
	train loss = 0.5394 acc = 0.7483 bacc = 0.6570
	test loss = 1.1669 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:16<00:00,  1.87it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.48it/s]


[Epoch 56 / 100]
	train loss = 0.5592 acc = 0.7203 bacc = 0.6300
	test loss = 0.7055 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:16<00:00,  1.87it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.45it/s]


[Epoch 57 / 100]
	train loss = 0.5518 acc = 0.7133 bacc = 0.6098
	test loss = 1.2055 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:17<00:00,  1.84it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.37it/s]


[Epoch 58 / 100]
	train loss = 0.6050 acc = 0.7343 bacc = 0.6015
	test loss = 1.1133 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:16<00:00,  1.86it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.48it/s]


[Epoch 59 / 100]
	train loss = 0.5814 acc = 0.7483 bacc = 0.6494
	test loss = 0.5813 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:16<00:00,  1.87it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.50it/s]


[Epoch 60 / 100]
	train loss = 0.5105 acc = 0.7622 bacc = 0.6744
	test loss = 0.7159 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:17<00:00,  1.85it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.35it/s]


[Epoch 61 / 100]
	train loss = 0.5609 acc = 0.7413 bacc = 0.6063
	test loss = 0.9522 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:17<00:00,  1.84it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.36it/s]


[Epoch 62 / 100]
	train loss = 0.5354 acc = 0.7622 bacc = 0.6515
	test loss = 0.9248 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:15<00:00,  1.88it/s]
Test: 100%|██████████| 20/20 [00:07<00:00,  2.50it/s]


[Epoch 63 / 100]
	train loss = 0.5909 acc = 0.7273 bacc = 0.6272
	test loss = 0.7600 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:16<00:00,  1.87it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.41it/s]


[Epoch 64 / 100]
	train loss = 0.5973 acc = 0.6993 bacc = 0.6001
	test loss = 1.1172 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:13<00:00,  1.94it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.49it/s]


[Epoch 65 / 100]
	train loss = 0.5486 acc = 0.7483 bacc = 0.6494
	test loss = 0.8532 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:15<00:00,  1.89it/s]
Test: 100%|██████████| 20/20 [00:07<00:00,  2.56it/s]


[Epoch 66 / 100]
	train loss = 0.5546 acc = 0.7063 bacc = 0.5744
	test loss = 0.5796 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:14<00:00,  1.91it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.45it/s]


[Epoch 67 / 100]
	train loss = 0.5404 acc = 0.7552 bacc = 0.6313
	test loss = 0.7462 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:15<00:00,  1.90it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.35it/s]


[Epoch 68 / 100]
	train loss = 0.5347 acc = 0.7483 bacc = 0.6417
	test loss = 0.9242 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:17<00:00,  1.84it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.40it/s]


[Epoch 69 / 100]
	train loss = 0.5600 acc = 0.7413 bacc = 0.6369
	test loss = 0.6997 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:17<00:00,  1.85it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.49it/s]


[Epoch 70 / 100]
	train loss = 0.5463 acc = 0.7552 bacc = 0.6466
	test loss = 1.3813 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:18<00:00,  1.83it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.49it/s]


[Epoch 71 / 100]
	train loss = 0.5679 acc = 0.7483 bacc = 0.6570
	test loss = 1.2065 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:16<00:00,  1.88it/s]
Test: 100%|██████████| 20/20 [00:07<00:00,  2.53it/s]


[Epoch 72 / 100]
	train loss = 0.5921 acc = 0.7413 bacc = 0.6063
	test loss = 0.6082 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:15<00:00,  1.88it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.43it/s]


[Epoch 73 / 100]
	train loss = 0.5084 acc = 0.7762 bacc = 0.6917
	test loss = 1.0153 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:17<00:00,  1.85it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.43it/s]


[Epoch 74 / 100]
	train loss = 0.5793 acc = 0.7273 bacc = 0.6348
	test loss = 0.6293 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:16<00:00,  1.88it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.40it/s]


[Epoch 75 / 100]
	train loss = 0.5772 acc = 0.7343 bacc = 0.6091
	test loss = 1.1547 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:17<00:00,  1.85it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.41it/s]


[Epoch 76 / 100]
	train loss = 0.5601 acc = 0.7273 bacc = 0.5813
	test loss = 0.6243 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:16<00:00,  1.87it/s]
Test: 100%|██████████| 20/20 [00:07<00:00,  2.51it/s]


[Epoch 77 / 100]
	train loss = 0.5391 acc = 0.7622 bacc = 0.6667
	test loss = 0.5997 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:16<00:00,  1.87it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.39it/s]


[Epoch 78 / 100]
	train loss = 0.5643 acc = 0.7343 bacc = 0.6397
	test loss = 0.9046 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:15<00:00,  1.89it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.47it/s]


[Epoch 79 / 100]
	train loss = 0.5421 acc = 0.7343 bacc = 0.6091
	test loss = 0.5791 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:16<00:00,  1.88it/s]
Test: 100%|██████████| 20/20 [00:07<00:00,  2.53it/s]


[Epoch 80 / 100]
	train loss = 0.5386 acc = 0.7413 bacc = 0.6445
	test loss = 1.4647 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:17<00:00,  1.84it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.37it/s]


[Epoch 81 / 100]
	train loss = 0.5776 acc = 0.7552 bacc = 0.6466
	test loss = 0.5936 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:17<00:00,  1.84it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.39it/s]


[Epoch 82 / 100]
	train loss = 0.5653 acc = 0.7343 bacc = 0.6015
	test loss = 1.2080 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:17<00:00,  1.86it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.36it/s]


[Epoch 83 / 100]
	train loss = 0.5723 acc = 0.7343 bacc = 0.6320
	test loss = 0.6541 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:18<00:00,  1.82it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.38it/s]


[Epoch 84 / 100]
	train loss = 0.5761 acc = 0.7203 bacc = 0.5994
	test loss = 0.5793 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:16<00:00,  1.86it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.41it/s]


[Epoch 85 / 100]
	train loss = 0.5666 acc = 0.7203 bacc = 0.6070
	test loss = 0.6457 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:17<00:00,  1.85it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.42it/s]


[Epoch 86 / 100]
	train loss = 0.5788 acc = 0.7413 bacc = 0.6140
	test loss = 0.7543 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:14<00:00,  1.91it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.50it/s]


[Epoch 87 / 100]
	train loss = 0.5655 acc = 0.7203 bacc = 0.6070
	test loss = 1.4710 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:18<00:00,  1.83it/s]
Test: 100%|██████████| 20/20 [00:07<00:00,  2.50it/s]


[Epoch 88 / 100]
	train loss = 0.6030 acc = 0.7273 bacc = 0.6272
	test loss = 1.1359 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:14<00:00,  1.91it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.41it/s]


[Epoch 89 / 100]
	train loss = 0.5370 acc = 0.7133 bacc = 0.6022
	test loss = 1.4140 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:16<00:00,  1.87it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.49it/s]


[Epoch 90 / 100]
	train loss = 0.5997 acc = 0.7273 bacc = 0.5966
	test loss = 0.6141 acc = 0.5000 bacc = 0.5000


Train: 100%|██████████| 143/143 [01:16<00:00,  1.86it/s]
Test: 100%|██████████| 20/20 [00:08<00:00,  2.41it/s]


[Epoch 91 / 100]
	train loss = 0.5550 acc = 0.7273 bacc = 0.6501
	test loss = 0.5949 acc = 0.5000 bacc = 0.5000


Train:  80%|███████▉  | 114/143 [00:57<00:17,  1.67it/s]

In [None]:
all_ids, all_preds = predict(model, image_encoder, test_dataloader, seq_len, device)
submission_df = pd.DataFrame({"ID": all_ids, "Predicted": all_preds})
submission_df.to_csv(os.path.join(save_dir, model_name + "01.csv"), index=False)

### Min

In [None]:
lr = 1e-2
n_epochs = 100

image_encoder = Resnet18Encoder()
image_encoder.to(device)
d_model = 512
model = BagImageModel(
    d_model=512,
    aggregator=lambda x: torch.min(x, dim=-2).values
)
model.to(device)
model_name = "ResNet18_min"

optimizer = optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss(reduction="sum")
transform_function = lambda x: x.reshape(3, x.shape[0], x.shape[1])
prediction_aggregator = PredictionAggregator.MIN

In [None]:
trainer(
    model, 
    image_encoder,
    optimizer, 
    criterion, 
    train_dataloader, 
    val_dataloader, 
    n_epochs,
    seq_len,
    transform_function,
    prediction_aggregator,
    device,
    model_name,
    verbose=True,
    verbose_every=1,
    save=True,
    save_every=1,
    save_dir=save_dir,
    plot=True
)

In [None]:
all_ids, all_preds = predict(model, image_encoder, test_dataloader, seq_len, device)
submission_df = pd.DataFrame({"ID": all_ids, "Predicted": all_preds})
submission_df.to_csv(os.path.join(save_dir, model_name + "01.csv"), index=False)

### Max

In [None]:
lr = 1e-2
n_epochs = 100

image_encoder = Resnet18Encoder()
image_encoder.to(device)
d_model = 512
model = BagImageModel(
    d_model=512,
    aggregator=lambda x: torch.max(x, dim=-2).values
)
model.to(device)
model_name = "ResNet18_max"

optimizer = optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss(reduction="sum")
transform_function = lambda x: x.reshape(3, x.shape[0], x.shape[1])
prediction_aggregator = PredictionAggregator.MIN

In [None]:
trainer(
    model, 
    image_encoder,
    optimizer, 
    criterion, 
    train_dataloader, 
    val_dataloader, 
    n_epochs,
    seq_len,
    transform_function,
    prediction_aggregator,
    device,
    model_name,
    verbose=True,
    verbose_every=1,
    save=True,
    save_every=1,
    save_dir=save_dir,
    plot=True
)

In [None]:
all_ids, all_preds = predict(model, image_encoder, test_dataloader, seq_len, device)
submission_df = pd.DataFrame({"ID": all_ids, "Predicted": all_preds})
submission_df.to_csv(os.path.join(save_dir, model_name + "01.csv"), index=False)