-----
----
# <b> DLMI Challenge </b>
# <b> MIL Instance </b>
# <b> Matteo MARENGO | matteo.marengo@ens-paris-saclay.fr </b>
# <b> Manal MEFTAH | manal.meftah@ens-paris-saclay.fr </b>


----
----
# <b> Import libraries </b>

In [None]:
import os
import torch
import torch.nn as nn
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from sklearn.model_selection import train_test_split
from sklearn.metrics import balanced_accuracy_score
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm
import numpy as np
from sklearn.metrics import f1_score, recall_score, precision_score

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

----
----
# <b> Dataset class </b>

In [None]:
class PatientImagesDataset(Dataset):
    """Instance-level dataset: every image inherits the label of its patient folder.

    Args:
        root_dir: Directory that contains one sub-folder per patient, named
            after the ``ID`` column of the annotations file.
        annotations_file: CSV file with at least the columns ``ID`` and ``LABEL``.
        transform: Optional callable applied to each PIL image.

    Attributes:
        images: Paths of all ``.jpg`` files, grouped by patient in CSV order and
            sorted by file name within a patient.
        labels: Patient label repeated for each image, aligned with ``images``.
    """

    def __init__(self, root_dir, annotations_file, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        # Read the annotations
        annotations = pd.read_csv(annotations_file)
        self.images = []
        self.labels = []
        # Assign labels to each image based on its bag
        for patient_id, label in zip(annotations['ID'], annotations['LABEL']):
            img_folder = os.path.join(self.root_dir, str(patient_id))
            # os.listdir returns entries in arbitrary order; sorting keeps the
            # index -> image mapping stable from one run to the next.
            for img_name in sorted(os.listdir(img_folder)):
                if img_name.endswith('.jpg'):
                    self.images.append(os.path.join(img_folder, img_name))
                    self.labels.append(label)

    def __len__(self):
        """Return the number of images across all patients."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the image at position ``idx``."""
        # Image.open is lazy and keeps the file open; the context manager closes
        # it once convert() has produced an independent in-memory copy.
        # Forcing RGB guarantees three channels whatever mode the file is stored in.
        with Image.open(self.images[idx]) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        label = torch.tensor(int(self.labels[idx]))
        return image, label

----
----
# <b> Custom ResNet and MIL class </b>

In [None]:
class CustomResNet(nn.Module):
    """Convolutional feature extractor with a stem and four stacked stages.

    The stem (7x7 stride-2 convolution, batch norm, ReLU, stride-2 max pool) is
    followed by four stages built by ``_make_layer`` whose widths are
    K, 2K, 4K and 8K. The stages are plain sequential stacks; no skip
    connections are defined here. ``forward`` returns the final feature map
    flattened to ``(batch, features)``.

    Args:
        K: Number of channels produced by the stem and the first stage.
    """

    def __init__(self, K):
        super().__init__()
        # Stem: two stride-2 operations reduce the spatial size by 4.
        self.conv1 = nn.Conv2d(3, K, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(K)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # (attribute name, input channels, output channels, stride of first conv)
        stage_specs = (
            ("conv2", K, K, 1),
            ("conv3", K, 2 * K, 2),
            ("conv4", 2 * K, 4 * K, 2),
            ("conv5", 4 * K, 8 * K, 2),
        )
        for attr_name, channels_in, channels_out, first_stride in stage_specs:
            setattr(self, attr_name, self._make_layer(channels_in, channels_out, stride=first_stride))

    def _make_layer(self, in_channels, out_channels, stride):
        """Build one stage: four 3x3 conv + batch-norm pairs, ReLU after the 1st and 3rd."""

        def conv3x3(channels_in, conv_stride=1):
            return nn.Conv2d(channels_in, out_channels, kernel_size=3,
                             stride=conv_stride, padding=1, bias=False)

        def norm():
            return nn.BatchNorm2d(out_channels)

        # Only the first convolution changes channel count and resolution.
        return nn.Sequential(
            conv3x3(in_channels, stride), norm(), nn.ReLU(inplace=True),
            conv3x3(out_channels), norm(),
            conv3x3(out_channels), norm(), nn.ReLU(inplace=True),
            conv3x3(out_channels), norm(),
        )

    def forward(self, x):
        """Run the stem and all stages, then flatten everything but dim 0."""
        features = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        for stage in (self.conv2, self.conv3, self.conv4, self.conv5):
            features = stage(features)
        return features.view(features.size(0), -1)

class MILModel(nn.Module):
    """Instance-level classifier: CustomResNet features followed by an MLP head.

    The head expects ``7 * 7 * 8 * K`` input features, which is what
    CustomResNet produces for 224x224 inputs (overall stride 32, 8K channels).
    ``forward`` returns one sigmoid probability per image, shape ``(batch, 1)``.

    Args:
        K: Base channel width passed to CustomResNet.
    """

    def __init__(self, K=64):
        super().__init__()
        self.feature_extractor = CustomResNet(K)
        flat_features = 7 * 7 * 8 * K
        self.classifier = nn.Sequential(
            nn.Linear(flat_features, 1024),
            nn.ReLU(True),
            nn.Dropout(0.5),
            nn.Linear(1024, 512),
            nn.ReLU(True),
            nn.Dropout(0.5),
            nn.Linear(512, 128),
            nn.ReLU(True),
            nn.Linear(128, 1),
        )

    def forward(self, x):
        """Return the per-image probability for a batch of images."""
        features = self.feature_extractor(x)
        # Flatten to (batch, features) before the fully connected head.
        flat = features.view(features.size(0), -1)
        logits = self.classifier(flat)
        return torch.sigmoid(logits)


In [None]:
# Resize to the 224x224 input the classifier head is sized for, then normalise
# with the commonly used ImageNet channel statistics.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

dataset = PatientImagesDataset(root_dir='/kaggle/input/dlmi-mms-data/dlmi-lymphocytosis-classification/trainset', annotations_file='/kaggle/input/dlmi-mms-data/dlmi-lymphocytosis-classification/trainset/trainset_true.csv', transform=transform)

# Split by patient, not by image. Every image carries its patient's label, so
# an image-level split puts images of the same patient on both sides and the
# validation metrics then measure memorisation of patients rather than
# generalisation to unseen ones.
patient_of_image = [os.path.basename(os.path.dirname(path)) for path in dataset.images]
label_of_patient = dict(zip(patient_of_image, dataset.labels))
patients = sorted(label_of_patient)
train_patients, val_patients = train_test_split(
    patients,
    test_size=0.2,
    stratify=[label_of_patient[patient] for patient in patients],
)
val_patient_set = set(val_patients)
train_indices = [i for i, patient in enumerate(patient_of_image) if patient not in val_patient_set]
val_indices = [i for i, patient in enumerate(patient_of_image) if patient in val_patient_set]

train_sampler = torch.utils.data.SubsetRandomSampler(train_indices)
val_sampler = torch.utils.data.SubsetRandomSampler(val_indices)

train_loader = DataLoader(dataset, batch_size=32, sampler=train_sampler)
val_loader = DataLoader(dataset, batch_size=32, sampler=val_sampler)


# Build the network and its optimisation objects.
model = MILModel(K=64).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.00001)
criterion = nn.BCELoss()

# Per-epoch history, consumed by the plots at the end of this cell.
train_losses, val_losses = [], []
val_accuracies, val_f1_scores = [], []
val_recall_scores, val_precision_scores = [], []

num_epochs = 150
for epoch in range(num_epochs):
    # ---- training pass ----
    model.train()
    train_loss = 0.0
    train_loader_tqdm = tqdm(train_loader, desc="Training")
    for images, labels in train_loader_tqdm:
        images = images.to(device)
        # BCELoss needs float targets; view(-1, 1) gives them one column.
        labels = labels.to(device).float().view(-1, 1)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        train_loss += batch_loss
        train_loader_tqdm.set_postfix({"Train Loss": f"{batch_loss:.4f}"})

    avg_train_loss = train_loss / len(train_loader)
    print(f"Training Loss: {avg_train_loss}")

    # ---- validation pass ----
    model.eval()
    val_loss = 0.0
    all_preds, all_labels = [], []

    with torch.no_grad():
        val_loader_tqdm = tqdm(val_loader, desc="Validation")
        for images, labels in val_loader_tqdm:
            images = images.to(device)
            labels = labels.to(device).float().view(-1, 1)
            outputs = model(images)
            val_loss += criterion(outputs, labels).item()
            # Rounding the sigmoid output thresholds the prediction at 0.5.
            all_preds.extend(outputs.round().cpu().numpy().flatten().tolist())
            all_labels.extend(labels.cpu().numpy().flatten().tolist())

    avg_val_loss = val_loss / len(val_loader)

    # Image-level metrics over the whole validation subset.
    f1 = f1_score(all_labels, all_preds, average='binary')
    recall = recall_score(all_labels, all_preds, average='binary')
    precision = precision_score(all_labels, all_preds, average='binary')
    val_acc = balanced_accuracy_score(all_labels, all_preds)

    # Record this epoch in the history lists.
    train_losses.append(avg_train_loss)
    val_losses.append(avg_val_loss)
    val_accuracies.append(val_acc)
    val_f1_scores.append(f1)
    val_recall_scores.append(recall)
    val_precision_scores.append(precision)

    print(f'Epoch {epoch+1}/{num_epochs} - Training Loss: {avg_train_loss:.4f}, Validation Loss: {avg_val_loss:.4f}, Val Balanced Acc: {val_acc:.4f}, F1: {f1:.4f}, Recall: {recall:.4f}, Precision: {precision:.4f}')

# Persist the weights reached after the last epoch.
torch.save(model.state_dict(), '/kaggle/working/model_MIL_Instance_Unbalanced.png'.replace('.png', '.pth'))

# One figure with three panels. savefig writes the whole figure as it stands
# after each panel, so every PNG also contains the panels drawn before it.
panels = (
    (1, ((train_losses, 'Train Loss'), (val_losses, 'Validation Loss')),
     'Losses', '/kaggle/working/Train_Validation_Losses_MIL_Instance_Unbalanced.png'),
    (2, ((val_accuracies, 'Val Balanced Accuracy'),),
     'Balanced Accuracy', '/kaggle/working/balanced_accuracy_MIL_Instance_Unbalanced.png'),
    (3, ((val_f1_scores, 'F1 Score'), (val_recall_scores, 'Recall'), (val_precision_scores, 'Precision')),
     'Validation Metrics', '/kaggle/working/metrics_MIL_Instance_Unbalanced.png'),
)
plt.figure(figsize=(12, 4))
for position, curves, panel_title, out_path in panels:
    plt.subplot(1, 3, position)
    for series, curve_label in curves:
        plt.plot(series, label=curve_label)
    plt.legend()
    plt.title(panel_title)
    plt.savefig(out_path)

plt.show()



----
----
# <b> Predict on the test set </b>

In [None]:
class PatientTestImagesDataset(Dataset):
    """Instance-level test dataset: yields each image together with its patient id.

    Args:
        root_dir: Directory whose sub-directories each hold one patient's
            images; the sub-directory name is used as the patient id.
        transform: Optional callable applied to each PIL image.

    Attributes:
        folders: Patient directories found under ``root_dir``, sorted by name.
        images: Paths of all ``.jpg`` files, grouped by patient folder and
            sorted by file name within a folder.
        patients: Patient id of each image, aligned with ``images``.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        # os.listdir order is arbitrary; sorting makes the instance order (and
        # therefore the row order of anything built from it) reproducible.
        self.folders = [os.path.join(root_dir, o) for o in sorted(os.listdir(root_dir))
                        if os.path.isdir(os.path.join(root_dir, o))]

        self.images = []
        self.patients = []

        for folder in self.folders:
            patient_id = os.path.basename(folder)
            for img_name in sorted(os.listdir(folder)):
                if img_name.endswith('.jpg'):
                    self.images.append(os.path.join(folder, img_name))
                    self.patients.append(patient_id)

    def __len__(self):
        """Return the number of images across all patients."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, patient_id)`` for the image at position ``idx``."""
        # Image.open is lazy and keeps the file open; the context manager closes
        # it once convert() has produced an independent in-memory copy.
        # Forcing RGB guarantees three channels whatever mode the file is stored in.
        with Image.open(self.images[idx]) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        patient_id = self.patients[idx]
        return image, patient_id

test_dataset = PatientTestImagesDataset(root_dir='/kaggle/input/dlmi-mms-data/dlmi-lymphocytosis-classification/testset', transform=transform)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

# Collect the thresholded (0/1) prediction of every image under its patient id.
model.eval()
patient_predictions = {}
with torch.no_grad():
    for images, patient_ids in tqdm(test_loader, desc="Predicting Test Data"):
        images = images.to(device)
        outputs = model(images)
        # Flatten to 1-D so each element is a scalar: calling int() on a
        # shape-(1,) array is deprecated in recent NumPy releases.
        predictions = outputs.round().cpu().numpy().flatten()

        for patient_id, prediction in zip(patient_ids, predictions):
            patient_predictions.setdefault(patient_id, []).append(int(prediction))

# Aggregation 1: majority vote. round() on the fraction of positive images
# rounds an exact 0.5 tie to 0 (Python rounds halves to the even integer).
final_predictions = []
for patient_id, preds in patient_predictions.items():
    majority_class = round(sum(preds) / len(preds))
    final_predictions.append((patient_id, majority_class))

df_predictions = pd.DataFrame(final_predictions, columns=['Id', 'Predicted'])

df_predictions.to_csv('/kaggle/working/Predicted_Instance_Mil_CustomResNet_Unbalanced.csv', index=False)

# Aggregation 2: asymmetric threshold. A patient is assigned class 0 as soon as
# more than 25% of its images are predicted 0, and class 1 otherwise.
final_predictions = []

for patient_id, preds in patient_predictions.items():
    zero_class_percentage = preds.count(0) / len(preds)
    if zero_class_percentage > 0.25:
        predicted_class = 0
    else:
        predicted_class = 1
    final_predictions.append((patient_id, predicted_class))

df_predictions = pd.DataFrame(final_predictions, columns=['Id', 'Predicted'])

df_predictions.to_csv('/kaggle/working/Predicted_Instance_Mil_CustomResNet_Balanced.csv', index=False)


----
----
# <b> Balanced/Weighted Loss </b>

In [None]:
# Model definition
class MILModel(nn.Module):
    """Embedding-level MIL model: mean-pool the instance features, then classify the bag.

    ``forward`` takes the instances of one bag stacked along dim 0, averages
    their CustomResNet features into a single embedding and returns the
    sigmoid probability of the bag with shape ``(1, 1)``.

    Args:
        K: Base channel width passed to CustomResNet.
    """

    def __init__(self, K=64):
        super().__init__()
        self.feature_extractor = CustomResNet(K)
        embedding_dim = 7 * 7 * 8 * K
        self.classifier = nn.Sequential(
            nn.Linear(embedding_dim, 1024),
            nn.ReLU(True),
            nn.Dropout(0.5),
            nn.Linear(1024, 512),
            nn.ReLU(True),
            nn.Dropout(0.5),
            nn.Linear(512, 128),
            nn.ReLU(True),
            nn.Linear(128, 1),
        )

    def forward(self, x):
        """Return the bag probability for the instances in ``x``."""
        # One feature vector per instance in the bag.
        instance_embeddings = self.feature_extractor(x)
        # Mean pooling over the instance dimension; keepdim keeps a batch axis of 1.
        bag_embedding = instance_embeddings.mean(dim=0, keepdim=True)
        bag_logit = self.classifier(bag_embedding)
        return torch.sigmoid(bag_logit)

# Dataset and train/validation index split for the embedding-level model.
# Images are only resized and converted to tensors (no normalisation here).
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

dataset = PatientImagesDataset(
    root_dir='/kaggle/input/dlmi-mms-data/dlmi-lymphocytosis-classification/trainset',
    annotations_file='/kaggle/input/dlmi-mms-data/dlmi-lymphocytosis-classification/trainset/trainset_true.csv',
    transform=transform,
)

# NOTE(review): the PatientImagesDataset visible in this file does not set an
# `annotations` attribute and yields single images, while this cell stratifies
# on `dataset.annotations['LABEL']` and treats each item as a bag. Presumably
# a bag-level variant of the dataset is expected here - TODO confirm.
all_indices = range(len(dataset))
train_indices, val_indices = train_test_split(
    all_indices,
    test_size=0.2,
    stratify=dataset.annotations['LABEL'],
)
train_sampler = torch.utils.data.SubsetRandomSampler(train_indices)
val_sampler = torch.utils.data.SubsetRandomSampler(val_indices)

def custom_collate_fn(batch):
    """Collate ``(bag, label)`` samples without stacking the bags.

    Returns:
        A tuple ``(bags, labels)`` where ``bags`` is a tuple holding each
        sample's bag unchanged and ``labels`` is a tensor built from the
        samples' labels.
    """
    # Transpose the list of pairs into a tuple of bags and a tuple of labels.
    bags, raw_labels = zip(*batch)
    return bags, torch.tensor(raw_labels)

# One bag per step: batch_size=1 together with the custom collate function
# means each batch is ((bag,), labels), with the bag left unstacked.
train_loader = DataLoader(dataset, batch_size=1, sampler=train_sampler, collate_fn=custom_collate_fn)
val_loader = DataLoader(dataset, batch_size=1, sampler=val_sampler, collate_fn=custom_collate_fn)

# Initialize model, optimizer, and criterion
model = MILModel(K=64).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.00001)
criterion = nn.BCELoss()

# Per-epoch history, consumed by the plots at the end of this cell.
train_losses = []
val_losses = []
val_accuracies = []
val_f1_scores = []
val_recall_scores = []
val_precision_scores = []

# Training loop
num_epochs = 200
for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0
    train_loader_tqdm = tqdm(train_loader, desc="Training")
    for batch in train_loader_tqdm:
        # batch[0] is the tuple of bags (length 1 here), batch[1] the label tensor.
        bag, label = batch[0][0], batch[1]
        bag = bag.to(device)
        label = label.to(device)
        optimizer.zero_grad()
        output = model(bag)
        # output[0] drops the leading axis so the prediction lines up with `label`.
        loss = criterion(output[0], label.float())
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        train_loader_tqdm.set_postfix({"Train Loss": f"{loss.item():.4f}"})

    print(f"Training Loss: {train_loss/len(train_loader)}")

    # Validation loop
    model.eval()
    all_preds = []
    all_labels = []
    val_loss = 0.0

    with torch.no_grad():
        val_loader_tqdm = tqdm(val_loader, desc="Validation")
        for batch in val_loader_tqdm:
            bag, label = batch[0][0], batch[1]
            bag = bag.to(device)
            label = label.to(device)
            output = model(bag)

            loss = criterion(output[0], label.float())
            val_loss += loss.item()
            # Rounding the sigmoid output thresholds the prediction at 0.5.
            preds = output[0].round()
            all_preds.extend(preds.cpu().numpy().flatten().tolist())
            all_labels.extend(label.cpu().numpy().flatten().tolist())

    # Calculate validation metrics (each one computed once per epoch)
    f1 = f1_score(all_labels, all_preds, average='binary')
    recall = recall_score(all_labels, all_preds, average='binary')
    precision = precision_score(all_labels, all_preds, average='binary')
    val_acc = balanced_accuracy_score(all_labels, all_preds)

    # Store metrics
    train_losses.append(train_loss/len(train_loader))
    val_losses.append(val_loss/len(val_loader))
    val_accuracies.append(val_acc)
    val_f1_scores.append(f1)
    val_recall_scores.append(recall)
    val_precision_scores.append(precision)

    # Print validation metrics
    print(f'Validation F1 Score: {f1}')
    print(f'Validation Recall: {recall}')
    print(f'Validation Precision: {precision}')
    print(f'Validation Loss: {val_loss/len(val_loader)}, Val Balanced Acc: {val_acc}')

# Save model weights
torch.save(model.state_dict(), '/kaggle/working/model_MIL_Embedding_Unbalanced.pth')

# Plot training and validation metrics. savefig writes the whole figure as it
# stands after each panel, so every PNG also contains the panels drawn before it.
plt.figure(figsize=(12, 4))
plt.subplot(1, 3, 1)
plt.plot(train_losses, label='Train Loss')
plt.plot(val_losses, label='Validation Loss')
plt.legend()
plt.title('Losses')
plt.savefig('/kaggle/working/Train_Validation_Losses_MIL_Embedding_Unbalanced.png')

plt.subplot(1, 3, 2)
plt.plot(val_accuracies, label='Val Balanced Accuracy')
plt.legend()
plt.title('Balanced Accuracy')
plt.savefig('/kaggle/working/balanced_accuracy_MIL_Embedding_Unbalanced.png')

plt.subplot(1, 3, 3)
plt.plot(val_f1_scores, label='F1 Score')
plt.plot(val_recall_scores, label='Recall')
plt.plot(val_precision_scores, label='Precision')
plt.legend()
plt.title('Validation Metrics')
plt.savefig('/kaggle/working/metrics_MIL_Embedding_Unbalanced.png')

plt.show()

----
----
# <b> MIL Instance with attention pooling </b>

In [None]:
# THE ATTENTION POOLING
class AttentionPooling(nn.Module):
    """Pool a bag of instance embeddings into one vector with learned attention.

    A single learnable vector scores every instance (dot product with its
    embedding); the scores are softmax-normalised across the bag and used as
    weights for a weighted sum of the embeddings.

    Args:
        in_features: Size of each instance embedding.
    """

    def __init__(self, in_features):
        super().__init__()
        self.attention_vector = nn.Parameter(torch.empty(in_features, dtype=torch.float32))
        # xavier_uniform_ needs at least two dimensions, so initialise through
        # a (1, in_features) view; the view shares storage with the parameter.
        with torch.no_grad():
            nn.init.xavier_uniform_(self.attention_vector.unsqueeze(0))

    def forward(self, x):
        """Return the attention-weighted sum of ``x`` over dim 0.

        Args:
            x: Tensor of shape ``(num_instances, in_features)``.

        Returns:
            Tensor of shape ``(in_features,)``.
        """
        # One scalar score per instance, normalised over the bag.
        # torch.softmax is used because torch.nn.functional is not imported
        # (as F or otherwise) in this file.
        attention_scores = torch.matmul(x, self.attention_vector)
        attention_weights = torch.softmax(attention_scores, dim=0)

        # Weighted sum of the embeddings; the weights add up to 1.
        weighted_features = x * attention_weights.unsqueeze(-1)
        pooled_features = weighted_features.sum(dim=0)
        return pooled_features

class MILModel(nn.Module):
    """Bag classifier: feature extractor, attention pooling, then an MLP head.

    ``forward`` returns the raw output of the last linear layer (no sigmoid is
    applied here).

    Args:
        K: Width factor; the pooling layer and the head are sized for
            ``7 * 7 * 8 * K`` features per instance.
    """

    def __init__(self, K=64):
        super().__init__()
        # NOTE(review): VisionTransformer is neither defined nor imported in the
        # visible part of this file, and nothing here shows that its output has
        # 7 * 7 * 8 * K features - TODO confirm both.
        self.feature_extractor = VisionTransformer(2)
        pooled_dim = K * 7 * 7 * 8
        self.attention_pooling = AttentionPooling(pooled_dim)
        self.classifier = nn.Sequential(
            nn.Linear(pooled_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
        )

    def forward(self, x):
        """Return the bag logit for the instances in ``x``."""
        instance_embeddings = self.feature_extractor(x)
        bag_embedding = self.attention_pooling(instance_embeddings)
        return self.classifier(bag_embedding)

# Dataset and train/validation index split for the attention-pooling model.
# Images are only resized and converted to tensors (no normalisation here).
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

dataset = PatientImagesDataset(
    root_dir='/kaggle/input/dlmi-mms-data/dlmi-lymphocytosis-classification/trainset',
    annotations_file='/kaggle/input/dlmi-mms-data/dlmi-lymphocytosis-classification/trainset/trainset_true.csv',
    transform=transform,
)

# NOTE(review): the PatientImagesDataset visible in this file does not set an
# `annotations` attribute and yields single images, while this cell stratifies
# on `dataset.annotations['LABEL']` and treats each item as a bag. Presumably
# a bag-level variant of the dataset is expected here - TODO confirm.
all_indices = range(len(dataset))
train_indices, val_indices = train_test_split(
    all_indices,
    test_size=0.2,
    stratify=dataset.annotations['LABEL'],
)
train_sampler = torch.utils.data.SubsetRandomSampler(train_indices)
val_sampler = torch.utils.data.SubsetRandomSampler(val_indices)

# Custom collate function
def custom_collate_fn(batch):
    """Collate ``(bag, label)`` samples without stacking the bags.

    Returns:
        A tuple ``(bags, labels)`` where ``bags`` is a tuple holding each
        sample's bag unchanged and ``labels`` is a tensor built from the
        samples' labels.
    """
    # Transpose the list of pairs into a tuple of bags and a tuple of labels.
    bags, raw_labels = zip(*batch)
    return bags, torch.tensor(raw_labels)

# One bag per step: batch_size=1 together with the custom collate function
# means each batch is ((bag,), labels), with the bag left unstacked.
train_loader = DataLoader(dataset, batch_size=1, sampler=train_sampler, collate_fn=custom_collate_fn)
val_loader = DataLoader(dataset, batch_size=1, sampler=val_sampler, collate_fn=custom_collate_fn)

# Initialize model, optimizer, and criterion
model = MILModel(K=64).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
# The model returns raw logits, so the loss applies the sigmoid itself.
# pos_weight up-weights the positive class by 4; it is created as a float
# tensor so it shares the floating dtype of the logits.
criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([4.0]).to(device))

# Initialize learning rate scheduler
# NOTE(review): the scheduler is stepped once per epoch below, so with
# step_size=20000 and 150 epochs the learning rate never decays - confirm
# whether per-batch stepping was intended.
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20000, gamma=0.1)


# Per-epoch history, consumed by the plots at the end of this cell.
train_losses = []
val_losses = []
val_accuracies = []
val_f1_scores = []
val_recall_scores = []
val_precision_scores = []

# Training loop
num_epochs = 150
for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0
    train_loader_tqdm = tqdm(train_loader, desc="Training")
    for batch in train_loader_tqdm:
        # batch[0] is the tuple of bags (length 1 here), batch[1] the label tensor.
        bag, label = batch[0][0], batch[1]
        bag = bag.to(device)
        label = label.to(device)
        optimizer.zero_grad()
        output = model(bag)
        loss = criterion(output, label.float())
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        train_loader_tqdm.set_postfix({"Train Loss": f"{loss.item():.4f}"})

    print(f"Epoch {epoch+1}, Training Loss: {train_loss/len(train_loader)}")
    lr_scheduler.step()

    # Validation loop
    model.eval()
    all_preds = []
    all_labels = []
    val_loss = 0.0

    with torch.no_grad():
        val_loader_tqdm = tqdm(val_loader, desc="Validation")
        for batch in val_loader_tqdm:
            bag, label = batch[0][0], batch[1]
            bag = bag.to(device)
            label = label.to(device)
            output = model(bag)
            # Logits -> probability, used only for the hard prediction below;
            # the loss still receives the raw logits.
            pred_output = torch.sigmoid(output)

            loss = criterion(output, label.float())
            val_loss += loss.item()
            preds = pred_output[0].round()
            all_preds.extend(preds.cpu().numpy().flatten().tolist())
            all_labels.extend(label.cpu().numpy().flatten().tolist())

    # Calculate validation metrics (each one computed once per epoch)
    f1 = f1_score(all_labels, all_preds, labels=[0, 1], average='binary')
    recall = recall_score(all_labels, all_preds, labels=[0, 1], average='binary')
    precision = precision_score(all_labels, all_preds, labels=[0, 1], average='binary')
    val_acc = balanced_accuracy_score(all_labels, all_preds)

    # Store metrics
    train_losses.append(train_loss/len(train_loader))
    val_losses.append(val_loss/len(val_loader))
    val_accuracies.append(val_acc)
    val_f1_scores.append(f1)
    val_recall_scores.append(recall)
    val_precision_scores.append(precision)

    # Print validation metrics
    print(f'Validation F1 Score: {f1}')
    print(f'Validation Recall: {recall}')
    print(f'Validation Precision: {precision}')
    print(f'Validation Loss: {val_loss/len(val_loader)}, Val Balanced Acc: {val_acc}')

# Save model weights
torch.save(model.state_dict(), '/kaggle/working/model_weights.pth')

# Plot training and validation metrics. savefig writes the whole figure as it
# stands after each panel, so every PNG also contains the panels drawn before it.
plt.figure(figsize=(12, 4))
plt.subplot(1, 3, 1)
plt.plot(train_losses, label='Train Loss')
plt.plot(val_losses, label='Validation Loss')
plt.legend()
plt.title('Losses')
plt.savefig('/kaggle/working/train_val_loss.png')

plt.subplot(1, 3, 2)
plt.plot(val_accuracies, label='Val Balanced Accuracy')
plt.legend()
plt.title('Balanced Accuracy')
plt.savefig('/kaggle/working/val_balanced_accuracy.png')

plt.subplot(1, 3, 3)
plt.plot(val_f1_scores, label='F1 Score')
plt.plot(val_recall_scores, label='Recall')
plt.plot(val_precision_scores, label='Precision')
plt.legend()
plt.title('Validation Metrics')
plt.savefig('/kaggle/working/validation_metrics.png')

plt.show()

----
----
# <b> Predict the test set </b>

In [None]:
# Bag-level inference on the test set: one prediction per dataset item, written
# next to the IDs of the test CSV in the same order.
test_dataset = PatientImagesDataset(root_dir='/kaggle/input/dlmi-mms-data/dlmi-lymphocytosis-classification/testset', annotations_file = '/kaggle/input/dlmi-mms-data/dlmi-lymphocytosis-classification/testset/testset_data.csv', transform=transform)
test_loader = DataLoader(test_dataset, batch_size=1, collate_fn=custom_collate_fn)

model.eval()

predictions_list = []

# NOTE(review): the score is rounded as-is. If `model` is the attention-pooling
# MILModel defined earlier in this file, its forward returns raw logits, so
# rounding does not threshold a probability at 0.5 - confirm whether a
# sigmoid is needed here before generating a submission.
with torch.no_grad():
    for batch in test_loader:
        # batch[0] is the tuple of bags (length 1 here); the labels are unused.
        bag = batch[0][0].to(device)
        output = model(bag)
        # Take the single score as a Python float. This works for outputs of
        # shape (1,) and (1, 1) alike and avoids calling int() on a 1-element
        # NumPy array, which is deprecated in recent NumPy releases.
        score = output.reshape(-1)[0].item()
        predictions_list.append(int(round(score)))

df = pd.read_csv("/kaggle/input/dlmi-mms-data/dlmi-lymphocytosis-classification/testset/testset_data.csv")

# Relies on the loader yielding items in the row order of the CSV
# (no shuffling or sampler is configured on test_loader).
submission_df = pd.DataFrame({
    'Id': df['ID'],
    'Predicted': predictions_list
})

submission_df.to_csv('/kaggle/working/MIL_Embedding_Unbalanced_CUSTOM_ResNet.csv', index=False)
print(submission_df.head())
