----
----
# <b> Challenge DLMI </b>
# <b> Multi-Instance Learning Embedding </b>
# <b> Matteo MARENGO | matteo.marengo@ens-paris-saclay.fr</b>
# <b> Manal MEFTAH | manal.meftah@ens-paris-saclay.fr </b>

----
----
# <b> Import Libraries </b>

In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from sklearn.model_selection import train_test_split
from sklearn.metrics import balanced_accuracy_score
import pandas as pd
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.metrics import f1_score, recall_score, precision_score

----
----
# <b> Define the device </b>

In [None]:
# Check for GPU availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

----
----
# <b> Define the Dataset Class </b>

In [None]:
# Dataset class
class PatientImagesDataset(Dataset):
    """Multi-instance dataset where each sample is one patient's bag of images.

    Every row of the annotations CSV describes one patient: the first column
    names the patient's image folder under ``root_dir`` and the second column
    holds the integer label of the bag.
    """

    def __init__(self, root_dir, annotations_file, transform=None):
        """Read the annotations table and remember where the images live.

        Args:
            root_dir: Directory containing one sub-folder of ``.jpg`` images
                per patient.
            annotations_file: Path to the CSV file (first column: folder
                name, second column: label).
            transform: Optional callable applied to every PIL image. It has
                to return tensors of a common shape, otherwise the bag
                cannot be stacked.
        """
        self.root_dir = root_dir
        self.annotations = pd.read_csv(annotations_file)
        self.transform = transform

    def __len__(self):
        """Return the number of patients (rows of the annotations table)."""
        return len(self.annotations)

    def __getitem__(self, idx):
        """Return ``(bag, label)`` for the patient stored at row ``idx``.

        Returns:
            A tuple ``(bag, label)`` where ``bag`` stacks the transformed
            images of the patient along a new first dimension and ``label``
            is a scalar integer tensor.

        Raises:
            ValueError: If the patient folder contains no ``.jpg`` file.
        """
        img_folder = os.path.join(self.root_dir, str(self.annotations.iloc[idx, 0]))
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # content of a bag identical from one run/machine to the next.
        image_paths = sorted(
            os.path.join(img_folder, file_name)
            for file_name in os.listdir(img_folder)
            if file_name.endswith('.jpg')
        )
        if not image_paths:
            # Fail with an explicit message instead of the opaque error
            # torch.stack raises on an empty list.
            raise ValueError(f"No .jpg images found in {img_folder}")

        bag = []
        for image_path in image_paths:
            # The context manager releases the file handle right away;
            # convert() loads the pixels and guarantees 3 channels even for
            # grayscale or RGBA files.
            with Image.open(image_path) as raw_image:
                image = raw_image.convert('RGB')
            if self.transform:
                image = self.transform(image)
            bag.append(image)
        bag = torch.stack(bag)
        # second column of the dataframe - the label of the bag
        label = torch.tensor(int(self.annotations.iloc[idx, 1]))
        return bag, label

----
----
# <b> Define the CustomResNet </b>

In [None]:
class CustomResNet(nn.Module):
    """Convolutional feature extractor: a ResNet-like stem followed by four stages.

    The stem is a 7x7 stride-2 convolution, batch norm, ReLU and a 3x3
    stride-2 max-pool. The stages ``conv2`` .. ``conv5`` are plain
    convolution / batch-norm stacks (no skip connections) whose widths are
    ``K``, ``2K``, ``4K`` and ``8K``; all but the first halve the resolution.
    The output is the final feature map flattened per input image.
    """

    def __init__(self, K):
        """Build the network; ``K`` is the channel width of the stem."""
        super().__init__()
        self.conv1 = nn.Conv2d(3, K, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(K)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # (attribute name, input width, output width, stride of first conv)
        stage_specs = (
            ("conv2", K, K, 1),
            ("conv3", K, 2 * K, 2),
            ("conv4", 2 * K, 4 * K, 2),
            ("conv5", 4 * K, 8 * K, 2),
        )
        for stage_name, width_in, width_out, stage_stride in stage_specs:
            self.add_module(stage_name, self._make_layer(width_in, width_out, stride=stage_stride))

    def _make_layer(self, in_channels, out_channels, stride):
        """Return one stage: four 3x3 conv + batch-norm pairs.

        Only the first convolution changes the channel count and applies
        ``stride``. A ReLU follows the first and the third pair.
        """
        def conv_bn(channels_in, conv_stride=1):
            # One 3x3 convolution (no bias) followed by its batch norm.
            return [
                nn.Conv2d(channels_in, out_channels, kernel_size=3,
                          stride=conv_stride, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
            ]

        modules = (
            conv_bn(in_channels, stride)
            + [nn.ReLU(inplace=True)]
            + conv_bn(out_channels)
            + conv_bn(out_channels)
            + [nn.ReLU(inplace=True)]
            + conv_bn(out_channels)
        )
        return nn.Sequential(*modules)

    def forward(self, x):
        """Run the stem and the four stages, then flatten each feature map."""
        pipeline = (
            self.conv1, self.bn1, self.relu, self.maxpool,
            self.conv2, self.conv3, self.conv4, self.conv5,
        )
        for module in pipeline:
            x = module(x)

        # One flat feature vector per image of the input batch.
        batch_size = x.size(0)
        return x.view(batch_size, -1)

----
----
# <b> Define the MIL model </b>

In [None]:
class MILModel(nn.Module):
    """Embedding-level multi-instance classifier.

    Each image of a bag is embedded by the feature extractor, the embeddings
    are averaged into a single bag embedding, and a linear layer followed by
    a sigmoid turns it into one probability for the whole bag.
    """

    def __init__(self, K=64):
        """Create the feature extractor (stem width ``K``) and the linear head."""
        super().__init__()
        self.feature_extractor = CustomResNet(K)
        # Flattened extractor output for 224x224 inputs: a 7x7 spatial map
        # with 8*K channels. Other input sizes need a different value here.
        embedding_size = 7 * 7 * 8 * K
        self.classifier = nn.Linear(embedding_size, 1)

    def forward(self, x):
        """Return the bag probability, shape ``(1, 1)``, for one bag of images ``x``."""
        instance_embeddings = self.feature_extractor(x)
        # Average over the instance dimension; keepdim leaves a batch of one
        # bag for the linear head.
        bag_embedding = torch.mean(instance_embeddings, dim=0, keepdim=True)
        bag_logit = self.classifier(bag_embedding)
        return bag_logit.sigmoid()

----
----
# <b> Train the model </b>

In [None]:
# Every image is resized to 224x224 and converted to a tensor.
transform = transforms.Compose([transforms.Resize((224, 224)),
                                transforms.ToTensor()])

dataset = PatientImagesDataset(
    root_dir='/kaggle/input/dlmi-mms-data/dlmi-lymphocytosis-classification/trainset',
    annotations_file='/kaggle/input/dlmi-mms-data/dlmi-lymphocytosis-classification/trainset/trainset_true.csv',
    transform=transform,
)

# Stratified 80/20 split over patients. The seed is fixed so that re-running
# this cell always yields the same validation patients: metrics stay
# comparable between runs, and validation bags cannot end up in the training
# subset of an already-trained model.
train_indices, val_indices = train_test_split(
    range(len(dataset)),
    test_size=0.2,
    stratify=dataset.annotations['LABEL'],
    random_state=42,
)
# Both loaders draw from the same dataset; the samplers restrict each one to
# its own subset of patient indices.
train_sampler = torch.utils.data.SubsetRandomSampler(train_indices)
val_sampler = torch.utils.data.SubsetRandomSampler(val_indices)

def custom_collate_fn(batch):
    """Collate ``(bag, label)`` samples without stacking the bags.

    Args:
        batch: Sequence of ``(bag, label)`` pairs.

    Returns:
        A pair ``(bags, labels)``: ``bags`` is a tuple holding the bag of
        each sample untouched, ``labels`` is a tensor built from the labels.
    """
    # Transpose [(bag, label), ...] into [(bag, ...), (label, ...)].
    columns = list(zip(*batch))
    bags, raw_labels = columns
    return bags, torch.tensor(raw_labels)

# One bag per batch: the collate function returns the bags as a tuple, and the
# loops below take its single element.
train_loader = DataLoader(dataset, batch_size=1, sampler=train_sampler, collate_fn=custom_collate_fn)
val_loader = DataLoader(dataset, batch_size=1, sampler=val_sampler, collate_fn=custom_collate_fn)

model = MILModel(K=64).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.00001)
# The model already applies a sigmoid, hence BCELoss (unweighted) rather than
# a logits-based loss.
criterion = nn.BCELoss()

# Per-epoch history, used for the plots after training.
train_losses = []
val_losses = []
val_accuracies = []
val_f1_scores = []
val_recall_scores = []
val_precision_scores = []

num_epochs = 150
for epoch in range(num_epochs):
    print(f"Epoch {epoch + 1}/{num_epochs}")

    # ---- Training pass ----
    model.train()
    train_loss = 0.0
    train_loader_tqdm = tqdm(train_loader, desc="Training")
    for bags, label in train_loader_tqdm:
        bag = bags[0].to(device)
        label = label.to(device)
        optimizer.zero_grad()
        output = model(bag)
        # output is (1, 1); output[0] has shape (1,), like the label tensor.
        loss = criterion(output[0], label.float())
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        train_loader_tqdm.set_postfix({"Train Loss": f"{loss.item():.4f}"})

    avg_train_loss = train_loss / len(train_loader)
    print(f"Training Loss: {avg_train_loss}")

    # ---- Validation pass ----
    model.eval()
    all_preds = []
    all_labels = []
    val_loss = 0.0

    with torch.no_grad():
        val_loader_tqdm = tqdm(val_loader, desc="Validation")
        for bags, label in val_loader_tqdm:
            bag = bags[0].to(device)
            label = label.to(device)
            output = model(bag)
            loss = criterion(output[0], label.float())
            val_loss += loss.item()

            # Rounding the probability gives the hard 0/1 prediction.
            preds = output[0].round()
            all_preds.extend(preds.cpu().numpy().flatten().tolist())
            all_labels.extend(label.cpu().numpy().flatten().tolist())

    avg_val_loss = val_loss / len(val_loader)

    # zero_division=0 keeps the default value (0) for undefined scores, e.g.
    # when no positive is predicted, without emitting a warning.
    f1 = f1_score(all_labels, all_preds, average='binary', zero_division=0)
    recall = recall_score(all_labels, all_preds, average='binary', zero_division=0)
    precision = precision_score(all_labels, all_preds, average='binary', zero_division=0)
    # Computed once per epoch and reused for both the history and the log.
    val_acc = balanced_accuracy_score(all_labels, all_preds)

    train_losses.append(avg_train_loss)
    val_losses.append(avg_val_loss)
    val_accuracies.append(val_acc)
    val_f1_scores.append(f1)
    val_recall_scores.append(recall)
    val_precision_scores.append(precision)

    print(f'Validation F1 Score: {f1}')
    print(f'Validation Recall: {recall}')
    print(f'Validation Precision: {precision}')
    print(f'Validation Loss: {avg_val_loss}, Val Balanced Acc: {val_acc}')

# Save model weights (state of the last epoch)
torch.save(model.state_dict(), '/kaggle/working/model_weights.pth')

# Plot training and validation metrics
plt.figure(figsize=(12, 4))
plt.subplot(1, 3, 1)
plt.plot(train_losses, label='Train Loss')
plt.plot(val_losses, label='Validation Loss')
plt.legend()
plt.title('Losses')

plt.subplot(1, 3, 2)
plt.plot(val_accuracies, label='Val Balanced Accuracy')
plt.legend()
plt.title('Balanced Accuracy')

plt.subplot(1, 3, 3)
plt.plot(val_f1_scores, label='F1 Score')
plt.plot(val_recall_scores, label='Recall')
plt.plot(val_precision_scores, label='Precision')
plt.legend()
plt.title('Validation Metrics')

plt.show()

----
----
# <b> Predict the test set and save it as a dataframe </b>

In [None]:
test_dataset = PatientImagesDataset(
    root_dir='/kaggle/input/dlmi-mms-data/dlmi-lymphocytosis-classification/testset',
    annotations_file='/kaggle/input/dlmi-mms-data/dlmi-lymphocytosis-classification/testset/testset_data.csv',
    transform=transform,
)
# No sampler and no shuffling: bags are visited in the row order of
# test_dataset.annotations, so predictions line up with that table.
test_loader = DataLoader(test_dataset, batch_size=1, collate_fn=custom_collate_fn)

model.eval()

predictions_list = []

with torch.no_grad():
    for bags, _ in test_loader:
        bag = bags[0].to(device)
        output = model(bag)
        # output is a single probability; round it to the 0/1 label.
        predicted_label = int(round(output.item()))
        predictions_list.append(predicted_label)

# Take the IDs from the very table that drove the predictions instead of
# re-reading a CSV from another path: this guarantees that row i of the
# submission holds the ID of the bag that produced prediction i.
df = test_dataset.annotations

submission_df = pd.DataFrame({
    'Id': df['ID'],
    'Predicted': predictions_list
})

submission_df.to_csv('/kaggle/working/custom_Resnet.csv', index=False)
print(submission_df.head())
