In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Walk the read-only Kaggle input tree and print the full path of every file,
# one path per line, in the order os.walk yields them.
for root_dir, _, file_names in os.walk('/kaggle/input'):
    full_paths = [os.path.join(root_dir, file_name) for file_name in file_names]
    if full_paths:
        print(*full_paths, sep='\n')

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
!pip install git+https://github.com/Mahmoodlab/CONCH.git

In [None]:
import os
from pathlib import Path
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from PIL import Image

from conch.open_clip_custom import create_model_from_pretrained

In [None]:
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torchvision import transforms

# Preprocessing applied to every image: resize to 224x224, convert to a tensor,
# then normalise each channel with the mean/std values below.
# (An augmented alternative -- random horizontal/vertical flips, rotation by up to
# 90 degrees, colour jitter of 0.2 and RandomResizedCrop(224, scale=(0.8, 1.0)) --
# was previously kept here as commented-out code.)
resize_step = transforms.Resize((224, 224))
normalize_step = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
transform = transforms.Compose([resize_step, transforms.ToTensor(), normalize_step])

# Location of the train / validation image folders (one sub-folder per class,
# as required by ImageFolder).
train_data_path = "/kaggle/input/train-tcga-coad-msi-mss/tcga_coad_msi_mss/train"
val_data_path = "/kaggle/input/train-tcga-coad-msi-mss/tcga_coad_msi_mss/val"

train_dataset = ImageFolder(train_data_path, transform=transform)
val_dataset = ImageFolder(val_data_path, transform=transform)

# Both loaders share batch size and worker count; only the training loader shuffles.
shared_loader_options = dict(batch_size=32, num_workers=4)
train_dataloader = DataLoader(train_dataset, shuffle=True, **shared_loader_options)
val_dataloader = DataLoader(val_dataset, shuffle=False, **shared_loader_options)

# Show the label-index -> class-name mapping; fail loudly if the dataset has none.
if not hasattr(train_dataloader.dataset, 'class_to_idx'):
    raise ValueError('Dataset does not have class_to_idx attribute')
idx_to_class = {
    class_index: class_name
    for class_name, class_index in train_dataloader.dataset.class_to_idx.items()
}
print(idx_to_class)

print("Number of training samples: ", len(train_dataloader.dataset))
print("Number of validation samples: ", len(val_dataloader.dataset))


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from conch.open_clip_custom import create_model_from_pretrained

# Load the pretrained CONCH model before using it. No earlier cell defines
# `model` or `device`, so without these lines this cell raises NameError on a
# fresh kernel (Restart & Run All).
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
checkpoint_path = '/kaggle/input/pytorch-model/pytorch_model.bin'
model, _ = create_model_from_pretrained(model_cfg='conch_ViT-B-16', checkpoint_path=checkpoint_path, device=device)

# Freeze the pre-trained layers: no parameter of the loaded model receives gradients.
for param in model.parameters():
    param.requires_grad = False

# Add a classification head: a single linear layer from the visual features to the class logits.
num_classes = 2
visual_output_dim = 512  # Update based on the actual output dimension from the visual trunk
model.classification_head = nn.Linear(visual_output_dim, num_classes).to(device)

# Make sure the head is trainable. It is attached after the freeze above, so it
# is the only part of the model with requires_grad=True.
for param in model.classification_head.parameters():
    param.requires_grad = True

# Define the criterion and optimizer; only the head's parameters are optimised.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.classification_head.parameters(), lr=1e-4)
# Alternative that was tried: optimizer = torch.optim.Adam(model.parameters(), lr=0.001)


In [None]:
# Training loop: one optimisation step per batch, using the `criterion` and
# `optimizer` defined in the previous cell.
num_epochs = 50
losses = []        # (epoch, mean loss of the last 100 batches); one entry per 100 batches
epoch_losses = []  # (epoch, mean loss over all batches of that epoch)

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0      # window accumulator for the progress log; reset after every log line
    epoch_loss_total = 0.0  # whole-epoch accumulator; never reset inside the epoch

    for i, data in enumerate(train_dataloader):
        inputs, labels = data[0].to(device), data[1].to(device)

        optimizer.zero_grad()

        # Forward pass: visual features -> classification head -> loss
        image_features, _ = model.visual(inputs)
        outputs = model.classification_head(image_features)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        running_loss += batch_loss
        epoch_loss_total += batch_loss

        if i % 100 == 99:    # Print every 100 batches
            print(f"[{epoch + 1}, {i + 1}] loss: {running_loss / 100:.3f}")
            losses.append((epoch + 1, running_loss / 100))  # Store epoch and average loss
            running_loss = 0.0

    # Average over the whole epoch. This must not use `running_loss`: that value
    # is reset every 100 batches and would only cover the trailing batches while
    # still being divided by the full batch count.
    epoch_loss = epoch_loss_total / len(train_dataloader)
    epoch_losses.append((epoch + 1, epoch_loss))

print("Finished Training")


In [None]:
# Plotting losses per batch
import matplotlib.pyplot as plt

# Plot of the 100-batch running averages collected during training.
# `losses` only receives an entry every 100 batches, so it stays empty when an
# epoch has fewer than 100 batches. `zip(*[])` yields nothing and the two-name
# unpacking would then raise ValueError, hence the guard.
if losses:
    epochs, batch_losses = zip(*losses)
    plt.plot(epochs, batch_losses, label='Batch Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Average Loss')
    plt.title('Training Loss per Batch')
    plt.legend()
    plt.show()
else:
    print("No 100-batch loss averages were recorded; skipping the batch-loss plot.")

# Plot of the per-epoch average losses (one point per completed epoch).
if epoch_losses:
    epochs, avg_losses = zip(*epoch_losses)
    plt.plot(epochs, avg_losses, label='Average Epoch Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Average Loss')
    plt.title('Training Loss per Epoch')
    plt.legend()
    plt.show()
else:
    print("No epoch losses were recorded; skipping the epoch-loss plot.")

In [None]:
# Evaluation: accuracy of the classification head over `val_dataloader`.
# The printed label reads "Test Accuracy", but the batches come from the validation loader.
model.eval()
correct, total = 0, 0

with torch.no_grad():
    for images, labels in val_dataloader:
        images = images.to(device)
        labels = labels.to(device)
        image_features, _ = model.visual(images)
        logits = model.classification_head(image_features)
        # Predicted class = index of the largest logit in each row.
        predicted = logits.argmax(dim=1)
        total += labels.shape[0]
        correct += (predicted == labels).sum().item()

print(f'Test Accuracy: {100 * correct / total:.2f}%')


In [None]:
# Accuracy of the classification head over `train_dataloader`, computed in eval
# mode and without gradient tracking.
model.eval()
correct = total = 0

with torch.no_grad():
    for batch_images, batch_labels in train_dataloader:
        batch_images, batch_labels = batch_images.to(device), batch_labels.to(device)
        features, _ = model.visual(batch_images)
        scores = model.classification_head(features)
        # Index of the highest score per sample is the predicted class.
        predicted = scores.max(dim=1).indices
        total += batch_labels.size(0)
        correct += torch.eq(predicted, batch_labels).sum().item()

print(f'Training Accuracy: {100 * correct / total:.2f}%')

> Test Accuracy: 81.99% --> convergence at 18 epochs 

> Training Accuracy: 90.80%

> 72.58% ---> only trained the classification head and evaluated it  (10 epochs)

> 72.84% ---> only trained the classification head and evaluated it  (30 epochs)

> 65.23% ---> trained the model's parameters and evaluated the classification head 

> 68.07% ---> only trained head and changed transformations, evaluated head

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torchvision import transforms
from conch.open_clip_custom import create_model_from_pretrained

# Image preprocessing: fixed 224x224 resize, tensor conversion, per-channel normalisation.
preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
transform = transforms.Compose(preprocessing_steps)

# Train / validation folders, read with ImageFolder (class = sub-folder name).
train_data_path = "/kaggle/input/train-tcga-coad-msi-mss/tcga_coad_msi_mss/train"
val_data_path = "/kaggle/input/train-tcga-coad-msi-mss/tcga_coad_msi_mss/val"
train_dataset = ImageFolder(train_data_path, transform=transform)
val_dataset = ImageFolder(val_data_path, transform=transform)

# Batches of 32 with 4 worker processes; shuffling is enabled for training only.
train_dataloader = DataLoader(train_dataset, batch_size=32, num_workers=4, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=32, num_workers=4, shuffle=False)

# Report the label-index -> class-name mapping, or abort if the dataset lacks one.
if not hasattr(train_dataloader.dataset, 'class_to_idx'):
    raise ValueError('Dataset does not have class_to_idx attribute')
idx_to_class = dict(
    (class_index, class_name)
    for class_name, class_index in train_dataloader.dataset.class_to_idx.items()
)
print(idx_to_class)

print("Number of training samples: ", len(train_dataloader.dataset))
print("Number of validation samples: ", len(val_dataloader.dataset))

# Pick the GPU when one is available and load the pretrained CONCH checkpoint onto it.
# The second value returned by create_model_from_pretrained is not used in this notebook.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
checkpoint_path = '/kaggle/input/pytorch-model/pytorch_model.bin'
model, _ = create_model_from_pretrained(
    model_cfg='conch_ViT-B-16',
    checkpoint_path=checkpoint_path,
    device=device,
)
_ = model.eval()

# Custom Model with Additional Layers in the Classification Head
class CustomModel(nn.Module):
    """Visual encoder of an existing model followed by an MLP classification head.

    Args:
        original_model: Object exposing a ``visual`` module. Calling ``visual(x)``
            must return a 2-tuple whose first element is the feature tensor fed
            to the head (the second element is ignored).
        visual_output_dim: Size of the feature vector produced by ``visual``.
        num_classes: Number of output logits.
        hidden_dims: Widths of the hidden layers of the head. Each hidden layer
            is ``Linear -> ReLU -> Dropout``. The default ``(256, 128)`` gives
            the original fixed architecture, so the module indices (and
            therefore the state_dict keys) are unchanged.
        dropout: Dropout probability used after every hidden layer.
    """

    def __init__(self, original_model, visual_output_dim, num_classes,
                 hidden_dims=(256, 128), dropout=0.5):
        super().__init__()
        self.visual = original_model.visual

        # Build the head layer by layer so its depth and width are configurable.
        head_layers = []
        in_features = visual_output_dim
        for hidden_width in hidden_dims:
            head_layers.append(nn.Linear(in_features, hidden_width))
            head_layers.append(nn.ReLU())
            head_layers.append(nn.Dropout(p=dropout))
            in_features = hidden_width
        head_layers.append(nn.Linear(in_features, num_classes))
        self.classification_head = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return class logits for a batch of inputs accepted by ``self.visual``."""
        # `visual` returns a pair; only the first element is passed to the head.
        x, _ = self.visual(x)
        return self.classification_head(x)

# Instantiate the custom model (wraps the visual encoder of the CONCH model loaded above)
model = CustomModel(original_model=model, visual_output_dim=512, num_classes=2).to(device)

# Freeze the pre-trained layers: only the classification head is trained
for param in model.visual.parameters():
    param.requires_grad = False

# Define the criterion and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.classification_head.parameters(), lr=1e-4)

# Early stopping parameters
patience = 10                  # epochs without sufficient improvement before stopping
min_delta = 0.001              # minimum drop in validation loss that counts as an improvement
best_loss = float('inf')
patience_counter = 0
prev_best_weights_path = None  # checkpoint file of the best epoch so far, if any

# Training Loop with Early Stopping
num_epochs = 50
losses = []  # (epoch, mean loss of the last 100 batches); one entry per 100 batches
epoch_losses = []  # (epoch, mean training loss over all batches of that epoch)
val_losses = []  # (epoch, mean validation loss over all validation batches)

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0      # window accumulator for the progress log; reset after every log line
    epoch_loss_total = 0.0  # whole-epoch accumulator; never reset inside the epoch

    for i, data in enumerate(train_dataloader):
        inputs, labels = data[0].to(device), data[1].to(device)

        optimizer.zero_grad()

        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        running_loss += batch_loss
        epoch_loss_total += batch_loss

        if i % 100 == 99:    # Print every 100 batches
            print(f"[{epoch + 1}, {i + 1}] loss: {running_loss / 100:.3f}")
            losses.append((epoch + 1, running_loss / 100))  # Store epoch and average loss
            running_loss = 0.0

    # Average over the whole epoch. `running_loss` cannot be used here: it is
    # reset every 100 batches, so it would only cover the trailing batches while
    # still being divided by the full batch count.
    epoch_loss = epoch_loss_total / len(train_dataloader)
    epoch_losses.append((epoch + 1, epoch_loss))

    # Validation phase
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for data in val_dataloader:
            inputs, labels = data[0].to(device), data[1].to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item()

    val_loss /= len(val_dataloader)
    val_losses.append((epoch + 1, val_loss))
    print(f"Validation loss after epoch {epoch + 1}: {val_loss:.3f}")

    # Save the model if validation loss has decreased by more than min_delta
    if val_loss < best_loss - min_delta:
        best_loss = val_loss
        patience_counter = 0
        best_weights_path = f'/kaggle/working/custom_model_weights_epoch_{epoch + 1}.pth'

        torch.save(model.state_dict(), best_weights_path)
        print(f"Model weights saved for epoch {epoch + 1}")

        # Delete the previous best weights if they exist, so only one checkpoint is kept
        if prev_best_weights_path is not None:
            os.remove(prev_best_weights_path)
            print(f"Deleted previous best weights: {prev_best_weights_path}")

        prev_best_weights_path = best_weights_path
    else:
        patience_counter += 1

    if patience_counter >= patience:
        print(f"Early stopping at epoch {epoch + 1}")
        break

# NOTE: at this point `model` holds the weights of the last epoch that ran.
# The best-validation weights are in the file at `prev_best_weights_path`.
print("Finished Training")

# Plotting the loss curve
import matplotlib.pyplot as plt

# Training curve: one point per 100-batch running average stored in `losses`,
# plotted against the epoch in which it was recorded.
epochs = [entry[0] for entry in losses]
loss_values = [entry[1] for entry in losses]

fig, ax = plt.subplots()
ax.plot(epochs, loss_values, label='Training Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training Loss Over Time')
ax.legend()
plt.show()

# Validation curve: one point per epoch stored in `val_losses`.
val_epochs = [entry[0] for entry in val_losses]
val_loss_values = [entry[1] for entry in val_losses]

fig, ax = plt.subplots()
ax.plot(val_epochs, val_loss_values, label='Validation Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Validation Loss Over Time')
ax.legend()
plt.show()


In [None]:
# Write the weights currently held by `model` to a file in the current working directory.
final_state_dict = model.state_dict()
torch.save(final_state_dict, 'custom_model_weights.pth')