In [5]:
import h5py
from transformers import CLIPProcessor, CLIPModel
import torch
from torch import nn
import torch.optim as optim
import tqdm 
import torch.utils.data as utils
import pandas as pd

In [None]:
# Directory holding the competition's HDF5 files; each split path below is
# this prefix concatenated with the split's file name.
path = '/kaggle/input/mva-dlmi-2025-histopathology-ood-classification/'
TRAIN_IMAGES_PATH, VAL_IMAGES_PATH, TEST_IMAGES_PATH = (
    path + split_file for split_file in ('train.h5', 'val.h5', 'test.h5')
)

In [7]:
# Select the compute backend once for the whole notebook, in order of
# preference: Apple MPS (only when both available and built), then CUDA,
# then CPU as the fallback.
if torch.backends.mps.is_available() and torch.backends.mps.is_built():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

class CustomDataset(utils.Dataset):
    """Dataset that eagerly loads one HDF5 split file into memory.

    Every top-level key of the file is treated as one sample: ``<key>/img`` is
    read as the image array and, for labelled splits, ``<key>/label`` is read
    as its scalar label.

    Args:
        file: Path of the HDF5 file to load.
        has_labels: Whether each sample has a ``label`` entry to read.
            ``None`` (the default) keeps the historical rule: every file is
            labelled except the one whose path equals ``TEST_IMAGES_PATH``.
            Pass ``True``/``False`` to use the class with any other file
            without relying on that module-level constant.

    Attributes are plain Python lists: ``images`` (one array per sample) and
    ``labels`` (one scalar per sample, empty when ``has_labels`` is false).
    """

    def __init__(self, file, has_labels=None):
        self.images, self.labels = [], []
        self.file = file
        # Decide once here so __getitem__ does not need to re-derive it (and
        # does not need the TEST_IMAGES_PATH global) on every access.
        self.has_labels = (file != TEST_IMAGES_PATH) if has_labels is None else bool(has_labels)
        with h5py.File(file, 'r') as dataset:
            for key in dataset.keys():
                # [:] materialises the full array so nothing references the
                # file after the `with` block closes it.
                self.images.append(dataset[key + '/img'][:])
                if self.has_labels:
                    # [()] reads the stored scalar value.
                    self.labels.append(dataset[key + '/label'][()])

    def __len__(self):
        """Return the number of samples loaded from the file."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return sample ``idx``.

        Returns:
            ``(image, label)`` as float32 tensors for labelled splits, or just
            the float32 ``image`` tensor for unlabelled ones.
        """
        image = torch.from_numpy(self.images[idx]).to(torch.float32)
        if not self.has_labels:
            return image
        label = torch.tensor(self.labels[idx], dtype=torch.float32)
        return image, label


def init_model(module):
    """Initialise a single module in place; intended for ``nn.Module.apply``.

    ``nn.Linear`` modules get Xavier-normal weights and, when they have a
    bias, an all-zero bias. Any other module type is left untouched.
    """
    if not isinstance(module, nn.Linear):
        return
    nn.init.xavier_normal_(module.weight)
    bias = module.bias
    if bias is None:
        return
    nn.init.zeros_(bias)

class Model(nn.Module):
    """Binary classifier: frozen PLIP vision encoder plus a trainable MLP head.

    The vision tower of the pretrained ``vinid/plip`` CLIP checkpoint is used
    as a fixed feature extractor (its parameters have ``requires_grad=False``
    and it is run under ``torch.no_grad()``). Only ``self.classifier`` is
    trained; it maps a 768-wide embedding to one sigmoid probability.
    """

    def __init__(self):
        super().__init__()
        self.device = device 
        
        self.processor = CLIPProcessor.from_pretrained("vinid/plip")
        self.vision_model = CLIPModel.from_pretrained("vinid/plip").vision_model
        
        # Freeze the encoder: it is never updated by the optimizer.
        for param in self.vision_model.parameters():
            param.requires_grad = False
        self.vision_model.to(self.device)

        self.classifier = nn.Sequential(
            nn.Linear(768, 384), 
            nn.ReLU(),
            nn.Linear(384, 192), 
            nn.ReLU(),
            nn.Linear(192, 48), 
            nn.ReLU(),
            nn.Linear(48, 1),
            nn.Sigmoid()
        )
        
        # Initialise ONLY the new head. nn.Module.apply visits every
        # submodule recursively, so applying init_model to `self` would also
        # re-initialise each nn.Linear inside the pretrained vision encoder
        # and throw away the weights that were just loaded.
        self.classifier.apply(init_model)
        self.to(self.device)
            
    def forward(self, image):
        """Return one probability per input image.

        Args:
            image: Image batch handed to the CLIP processor. ``do_rescale`` is
                disabled, so pixel values are presumably already scaled to
                [0, 1] (confirm against the dataset).

        Returns:
            Output of the sigmoid head, one value per image in a trailing
            dimension of size 1.
        """
        inputs = self.processor(
            images=image, 
            return_tensors="pt", 
            padding=True, 
            do_rescale=False
        )
        
        pixel_values = inputs['pixel_values'].to(self.device, dtype=torch.float32)
        
        # The encoder is frozen, so skip building its autograd graph.
        with torch.no_grad():
            # Keep the hidden state of the first token (index 0) as the
            # image embedding.
            embedding = self.vision_model(pixel_values=pixel_values).last_hidden_state[:, 0, :]
        
        return self.classifier(embedding)


In [20]:
# Batch size used for each split.
train_batch_size = 100
val_batch_size = 10
test_batch_size = 10

# Load every split fully into memory.
train_dataset = CustomDataset(file=TRAIN_IMAGES_PATH)
val_dataset = CustomDataset(file=VAL_IMAGES_PATH)
test_dataset = CustomDataset(file=TEST_IMAGES_PATH)

# Train and validation batches are shuffled; the test loader keeps file order
# so predictions can later be matched back to the sample ids.
train_loader = utils.DataLoader(
    dataset=train_dataset,
    batch_size=train_batch_size,
    shuffle=True,
)
val_loader = utils.DataLoader(
    dataset=val_dataset,
    batch_size=val_batch_size,
    shuffle=True,
)
test_loader = utils.DataLoader(
    dataset=test_dataset,
    batch_size=test_batch_size,
    shuffle=False,
)

In [8]:
# Instantiate the classifier on the selected device, with binary
# cross-entropy on its sigmoid output and Adam as the optimizer.
model = Model().to(device)
loss_fn = nn.BCELoss()
optimizer = optim.Adam(params=model.parameters(), lr=3e-4)

In [23]:
def calculate_accuracy(predictions, labels):
    """Return the accuracy, in percent, of thresholded predictions.

    A prediction counts as class 1 when it is at least 0.5. The result is a
    tensor holding the mean of the element-wise matches times 100.
    """
    predicted_classes = predictions.ge(0.5).float()
    matches = predicted_classes.eq(labels).float()
    return matches.mean().mul(100)

In [31]:
def use(epochs):
    """Train the global ``model`` and return its predictions for the test split.

    For each epoch the model is fitted on ``train_loader`` and evaluated on
    ``val_loader``; both accuracies are shown on the progress bar. Accuracies
    are the mean of the per-batch accuracies, so a smaller final batch weighs
    as much as a full one.

    Args:
        epochs: Number of passes over the training data.

    Returns:
        A list of ``(sample_id, prediction)`` pairs, where the ids are the
        top-level keys of the test HDF5 file in file order and each prediction
        is 0.0 or 1.0.
    """
    train_accs, val_accs = [], []

    progress = tqdm.tqdm(range(epochs))
    for _ in progress:
        model.train()
        # Running sums are reset every epoch; keeping them across epochs
        # would make later epochs report more than 100 %.
        total_acc = 0.0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            # squeeze(-1) drops only the trailing size-1 dimension, so a
            # batch with a single sample keeps its batch dimension and still
            # matches the shape of `labels`.
            y_pred = model(images).squeeze(-1)
            loss = loss_fn(y_pred, labels)
            loss.backward()
            optimizer.step()

            total_acc += calculate_accuracy(y_pred, labels).item()

        train_acc = total_acc / len(train_loader)
        train_accs.append(train_acc)

        model.eval()
        total_val_acc = 0.0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                val_pred = model(images).squeeze(-1)
                total_val_acc += calculate_accuracy(val_pred, labels).item()

        val_acc = total_val_acc / len(val_loader)
        val_accs.append(val_acc)

        progress.set_postfix(train_acc=train_acc, val_acc=val_acc)

    # Sample ids in file order; test_loader is not shuffled, so predictions
    # come out in the same order.
    with h5py.File(TEST_IMAGES_PATH, 'r') as f:
        idx = list(f.keys())

    test_predictions = []
    model.eval()
    with torch.no_grad():
        for images in test_loader:
            images = images.to(device)
            # Same ">= 0.5" decision rule as calculate_accuracy.
            test_pred = (model(images).squeeze(-1) >= 0.5).float()
            test_predictions.append(test_pred)

    # torch.cat rejects an empty list, so an empty test set yields no rows.
    test = list(torch.cat(test_predictions).cpu().numpy()) if test_predictions else []

    return list(zip(idx, test))

In [None]:
# Train for a single epoch and collect the (id, prediction) pairs for the test split.
result = use(epochs=1)

In [None]:
# Tabulate the (id, prediction) pairs and write the submission file without
# the DataFrame index column.
df = pd.DataFrame(data=result, columns=['ID', 'Pred'])
df.to_csv(path_or_buf='submission.csv', index=False)