In [2]:
import h5py
from transformers import CLIPProcessor, CLIPModel
import torch
from torch import nn
import tqdm 
import torch.utils.data as utils
import pandas as pd

In [3]:
# Root of the Kaggle competition data; each split is one HDF5 file under it.
path = '/kaggle/input/mva-dlmi-2025-histopathology-ood-classification/'
TRAIN_IMAGES_PATH, VAL_IMAGES_PATH, TEST_IMAGES_PATH = (
    path + name for name in ('train.h5', 'val.h5', 'test.h5')
)

In [4]:
device = torch.device('cuda' if torch.cuda.is_available()
                      else 'cpu'
) #'mps' if torch.backends.mps.is_available() and torch.backends.mps.is_built() 
scaler = torch.amp.GradScaler('cuda')

In [5]:
# Earlier CustomDataset, kept only as an inert string literal (it is never executed).
# It ran the PLIP CLIPProcessor on every image once, at load time, and stored the
# resulting pixel_values. The active CustomDataset in the next cell stores raw images
# and leaves preprocessing to the model's forward pass.
# NOTE: as the cell's last expression, this string is echoed as the cell's output.
'''
class CustomDataset(utils.Dataset):
    def __init__(self, file):
        self.pixel_values = []
        self.labels = [] if file != TEST_IMAGES_PATH else None
        self.processor = CLIPProcessor.from_pretrained("vinid/plip")

        with h5py.File(file, 'r') as dataset:
            for key in dataset.keys():
                image = dataset[key+'/img'][:]
                inputs = self.processor(
                    images=image,
                    return_tensors="pt",
                    do_rescale=False,
                    padding=True
                )
                self.pixel_values.append(inputs["pixel_values"].squeeze(0)) 
                
                if file != TEST_IMAGES_PATH:
                    label = dataset[key+'/label'][()]
                    self.labels.append(label)
        
    def __len__(self):
        return len(self.pixel_values) 
    
    def __getitem__(self, idx):
        if self.labels is not None: 
            return self.pixel_values[idx], torch.tensor(self.labels[idx], dtype=torch.float32)
        return self.pixel_values[idx] 
'''

In [None]:
class CustomDataset(utils.Dataset):
    """In-memory dataset over one HDF5 split file.

    Every top-level key of the file is read eagerly: its ``img`` array always, and its
    scalar ``label`` unless ``file`` is the test split (``TEST_IMAGES_PATH``).

    Items are ``(image, label)`` float32 tensors for labelled splits and a bare float32
    image tensor for the test split.
    """

    def __init__(self, file):
        is_labelled = file != TEST_IMAGES_PATH
        self.images = []
        self.labels = [] if is_labelled else None

        with h5py.File(file, 'r') as h5:
            for key in h5.keys():
                self.images.append(h5[key + '/img'][:])
                if is_labelled:
                    self.labels.append(h5[key + '/label'][()])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = torch.from_numpy(self.images[idx]).to(torch.float32)
        if self.labels is None:
            return image
        return image, torch.tensor(self.labels[idx], dtype=torch.float32)

In [6]:
def init_model(module):
    """Xavier-normal weights and zero biases for ``nn.Linear`` modules.

    Intended for ``Module.apply``; any module that is not an ``nn.Linear`` is left
    untouched.
    """
    if not isinstance(module, nn.Linear):
        return
    nn.init.xavier_normal_(module.weight)
    if module.bias is not None:
        nn.init.zeros_(module.bias)

In [7]:
# Earlier Model, kept only as an inert string literal (it is never executed). It took
# pixel_values already preprocessed by the dataset; the active Model in the next cell
# runs the CLIPProcessor inside forward() instead.
# NOTE(review): if revived, `self.apply(init_model)` below would also recurse into the
# pretrained vision_model and re-initialise its nn.Linear layers; apply the init to
# self.classifier only.
'''
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.device = device 
        
        self.vision_model = CLIPModel.from_pretrained("vinid/plip").vision_model
        
        for param in self.vision_model.parameters():
            param.requires_grad = False
        self.vision_model.to(self.device)

        self.classifier = nn.Sequential(
            nn.Linear(768, 384), 
            nn.ReLU(),
            nn.Linear(384, 192), 
            nn.ReLU(),
            nn.Linear(192, 48), 
            nn.ReLU(),
            nn.Linear(48, 1),
            nn.Sigmoid()
        )
        
        self.apply(init_model)
        self.to(self.device)
            
    def forward(self, pixel_values):
                
        with torch.no_grad():
            embedding = self.vision_model(pixel_values=pixel_values).last_hidden_state[:, 0, :]
        
        return self.classifier(embedding)
'''

In [None]:
class Model(nn.Module):
    """Binary classifier: frozen PLIP vision encoder followed by a trainable MLP head.

    ``forward`` receives a batch of raw images and preprocesses them with the PLIP
    ``CLIPProcessor``. It takes the first token of the vision transformer's last hidden
    state as the image embedding and maps it through the MLP to a sigmoid probability of
    shape ``(batch, 1)``.
    """

    def __init__(self):
        super().__init__()
        self.device = device

        self.processor = CLIPProcessor.from_pretrained("vinid/plip")
        self.vision_model = CLIPModel.from_pretrained("vinid/plip").vision_model

        # The pretrained backbone is a fixed feature extractor: no gradients flow into it.
        for param in self.vision_model.parameters():
            param.requires_grad = False
        self.vision_model.to(self.device)

        # Input width must equal the vision model's hidden size (the embedding read in forward()).
        self.classifier = nn.Sequential(
            nn.Linear(768, 384),
            nn.ReLU(),
            nn.Linear(384, 192),
            nn.ReLU(),
            nn.Linear(192, 48),
            nn.ReLU(),
            nn.Linear(48, 1),
            nn.Sigmoid(),
        )

        # Initialise ONLY the freshly created head. `self.apply(...)` would also recurse
        # into `self.vision_model` and overwrite every pretrained nn.Linear of the frozen
        # backbone with random Xavier weights and zero biases, discarding the PLIP features.
        self.classifier.apply(init_model)
        self.to(self.device)

    def forward(self, image):
        """Return class-1 probabilities, shape ``(batch, 1)``, for a batch of images.

        ``do_rescale=False`` skips the processor's pixel rescaling step. This assumes the
        stored images are already scaled to [0, 1]; TODO confirm against the HDF5 data.
        """
        # NOTE(review): the HF image processor works on numpy/CPU data, so a CUDA tensor
        # passed here is presumably copied back to the host on every batch.
        inputs = self.processor(
            images=image,
            return_tensors="pt",
            padding=True,
            do_rescale=False
        )

        pixel_values = inputs['pixel_values'].to(self.device, dtype=torch.float32)

        # The backbone is frozen, so skip building its autograd graph.
        with torch.no_grad():
            embedding = self.vision_model(pixel_values=pixel_values).last_hidden_state[:, 0, :]

        return self.classifier(embedding)

In [None]:
# Labelled training split, read fully into memory.
train_dataset = CustomDataset(file=TRAIN_IMAGES_PATH)

In [None]:
# Labelled validation split, read fully into memory.
val_dataset = CustomDataset(file=VAL_IMAGES_PATH)

In [None]:
# Unlabelled test split: its items are bare image tensors.
test_dataset = CustomDataset(file=TEST_IMAGES_PATH)

In [None]:
train_batch_size, val_batch_size, test_batch_size = 1000, 680, 850

# Pinned host memory only speeds up host-to-GPU copies; skip it on CPU-only machines.
PIN_MEMORY = torch.cuda.is_available()

train_loader = utils.DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True,
                                num_workers=8, pin_memory=PIN_MEMORY)
# Evaluation needs no shuffling, and a fixed order keeps per-epoch validation numbers comparable.
val_loader = utils.DataLoader(val_dataset, batch_size=val_batch_size, shuffle=False,
                              num_workers=6, pin_memory=PIN_MEMORY)
# Must never be shuffled: test predictions are paired with the HDF5 keys by position.
test_loader = utils.DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False,
                               num_workers=4, pin_memory=PIN_MEMORY)

In [None]:
# Only the MLP head is optimised; the PLIP backbone inside Model is frozen.
model = Model().to(device)
# Model ends in nn.Sigmoid, so the loss consumes probabilities rather than logits.
loss_fn = nn.BCELoss()
optimizer = torch.optim.AdamW(
    model.classifier.parameters(),
    lr=1e-4,
    weight_decay=1e-4,
)

In [None]:
def calculate_accuracy(predictions, labels):
    """Percentage of thresholded predictions that equal ``labels``.

    ``predictions`` are probabilities; values ``>= 0.5`` count as class 1. Returns a
    0-dim float tensor in ``[0, 100]``.
    """
    hits = (predictions >= 0.5).float().eq(labels)
    return 100 * hits.float().mean()

In [None]:
def _autocast():
    """Float16 autocast on CUDA; a disabled (no-op) context on any other device."""
    return torch.amp.autocast('cuda', enabled=device.type == 'cuda')


def _train_one_epoch(epoch):
    """Run one optimisation pass over ``train_loader`` and return accuracy in percent.

    Accuracy is weighted by batch size, so a short final batch counts proportionally.
    """
    model.train()
    weighted_acc, seen = 0.0, 0

    for pixel_values, labels in tqdm.tqdm(train_loader, desc=f'Train {epoch+1}', leave=True):
        pixel_values = pixel_values.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad()

        with _autocast():
            # squeeze(-1), not squeeze(): a batch of one must stay 1-D to match `labels`.
            y_pred = model(pixel_values).squeeze(-1)
        # BCELoss raises inside autocast-enabled regions, so compute it outside, in float32.
        loss = loss_fn(y_pred.float(), labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        batch_size = labels.size(0)
        weighted_acc += calculate_accuracy(y_pred.detach(), labels).item() * batch_size
        seen += batch_size

    return weighted_acc / max(seen, 1)


def _evaluate(epoch):
    """Return accuracy on ``val_loader`` in percent, weighted by batch size."""
    model.eval()
    weighted_acc, seen = 0.0, 0

    with torch.no_grad():
        for pixel_values, labels in tqdm.tqdm(val_loader, desc=f'Val {epoch+1}', leave=True):
            pixel_values = pixel_values.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            with _autocast():
                val_pred = model(pixel_values).squeeze(-1)

            batch_size = labels.size(0)
            weighted_acc += calculate_accuracy(val_pred, labels).item() * batch_size
            seen += batch_size

    return weighted_acc / max(seen, 1)


def _predict_test():
    """Predict hard 0/1 labels for ``test_loader`` and pair them with the test file's keys.

    Relies on ``test_loader`` iterating the test file in ``keys()`` order (no shuffling).
    """
    with h5py.File(TEST_IMAGES_PATH, 'r') as f:
        ids = list(f.keys())

    model.eval()
    batch_preds = []
    with torch.no_grad():
        for pixel_values in tqdm.tqdm(test_loader, desc='Testing'):
            pixel_values = pixel_values.to(device, non_blocking=True)
            with _autocast():
                probs = model(pixel_values).squeeze(-1)
            # Same >= 0.5 threshold as calculate_accuracy, so validation accuracy reflects
            # exactly how the submission labels are produced.
            batch_preds.append((probs >= 0.5).float())

    preds = list(torch.cat(batch_preds).cpu().numpy()) if batch_preds else []
    # zip() would silently truncate on a mismatch and misalign the submission.
    if len(preds) != len(ids):
        raise RuntimeError(f'{len(preds)} test predictions for {len(ids)} test ids')
    return list(zip(ids, preds))


def use(epochs):
    """Train the classifier head for ``epochs`` epochs, then predict on the test split.

    Uses the module-level ``model``, ``optimizer``, ``scaler``, ``loss_fn`` and data
    loaders, and prints train/validation accuracy (percent) after every epoch.

    Returns:
        List of ``(id, prediction)`` tuples, one per test image. ``id`` is the HDF5 key
        and ``prediction`` is 0.0 or 1.0.

    Raises:
        RuntimeError: if the number of test predictions differs from the number of keys.
    """
    for epoch in tqdm.tqdm(range(epochs), desc='Epochs'):
        train_acc = _train_one_epoch(epoch)
        val_acc = _evaluate(epoch)
        print(f"Epoch{epoch+1}/{epochs} - Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}")

    return _predict_test()

In [None]:
# Train for one epoch, report accuracies, then predict on the test split.
result = use(epochs=1)

In [None]:
# Submission file: one row per test image, holding its HDF5 key and the predicted label.
df = pd.DataFrame.from_records(result, columns=['ID', 'Pred'])
df.to_csv('submission.csv', index=False)