In [None]:
import numpy as np
import pandas as pd
import openslide
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import random
import os

import warnings
warnings.filterwarnings('ignore')


In [None]:
def percentage(mask):
    """Return the percentage (0-100) of elements in ``mask`` that are > 0.

    Args:
        mask: Array-like of numeric values (for example a grayscale
            tissue-mask patch converted with ``np.array``).

    Returns:
        float: ``100 * count(mask > 0) / mask.size``, or ``0.0`` when the
        mask has no elements.
    """
    mask = np.asarray(mask)
    if mask.size == 0:
        # 0 / 0 would yield NaN; the accompanying RuntimeWarning is hidden by
        # the notebook-wide warnings filter, so report "no foreground" instead.
        return 0.0
    return float(np.count_nonzero(mask > 0)) / mask.size * 100

def extract_patches(im_slide, ms_slide, level, size_mask, num_patches_needed):
    """Sample image patches whose tissue-mask tile is more than 60 % foreground.

    The mask slide is cut into a regular grid of ``size_mask`` tiles at
    ``level``. Tiles are drawn in a pseudo-random order from a private RNG
    seeded with 42, so the selection is reproducible from call to call. For
    every tile whose mask has more than 60 % non-zero pixels, the matching
    region of the image slide is read and kept.

    Args:
        im_slide: Slide object for the image. Must expose
            ``level_dimensions`` and ``read_region(location, level, size)``.
        ms_slide: Slide object for the tissue mask. Must additionally expose
            ``level_downsamples``.
        level: Pyramid level used for both slides.
        size_mask: ``(width, height)`` of one mask tile, in pixels at ``level``.
        num_patches_needed: Number of patches to return.

    Returns:
        list[np.ndarray]: ``num_patches_needed`` arrays, each obtained from
        ``np.array(region.convert("RGB"))``.

    Raises:
        ValueError: If every tile has been examined and fewer than
            ``num_patches_needed`` passed the threshold. (Previously this
            situation made the sampling loop spin forever.)
    """
    # read_region() takes its location in level-0 coordinates, so tile origins
    # expressed in mask-level pixels are scaled back up for each slide.
    mask_to_level0 = int(ms_slide.level_downsamples[level])
    size_scale = im_slide.level_dimensions[level][0] // ms_slide.level_dimensions[level][0]
    coord_scale = im_slide.level_dimensions[0][0] // ms_slide.level_dimensions[level][0]
    size_image = (size_mask[0] * size_scale, size_mask[1] * size_scale)

    ms_width, ms_height = ms_slide.level_dimensions[level]
    tiles = [
        (x_ms, y_ms)
        for y_ms in range(0, ms_height, size_mask[1])
        for x_ms in range(0, ms_width, size_mask[0])
    ]

    # A private generator gives the same draw sequence as random.seed(42)
    # followed by random.randint(...), without resetting the process-wide RNG.
    rng = random.Random(42)
    tried = set()
    image_patches = []

    while len(image_patches) < num_patches_needed and len(tried) < len(tiles):
        index = rng.randint(0, len(tiles) - 1)
        if index in tried:
            continue
        tried.add(index)

        x_ms, y_ms = tiles[index]
        mask_patch = ms_slide.read_region(
            (x_ms * mask_to_level0, y_ms * mask_to_level0), level, size_mask
        ).convert("L")

        # Only pay for the (much larger) image read once the tile is accepted.
        if percentage(np.array(mask_patch)) > 60:
            image_patch = im_slide.read_region(
                (x_ms * coord_scale, y_ms * coord_scale), level, size_image
            ).convert("RGB")
            image_patches.append(np.array(image_patch))

    if len(image_patches) < num_patches_needed:
        raise ValueError(
            f"only {len(image_patches)} of {len(tiles)} mask tiles exceed 60% tissue; "
            f"{num_patches_needed} patches were requested"
        )

    return image_patches


In [None]:

# DATASET
class PatchDataset(Dataset):
    """Dataset yielding, per slide, a stack of tissue patches and per-patch labels.

    One item corresponds to one entry of ``images_dir``. The mask file is
    looked up in ``masks_dir`` under the same name with ``.tif`` replaced by
    ``_tissue.tif``, and the label is the ``follow_up_years`` value of the CSV
    row whose ``case_id`` equals the image file name.
    """

    def __init__(self, images_dir, masks_dir, csv_file, num_patches_per_image, level, size_mask):
        """Store the configuration, read the label CSV and list the slides.

        Args:
            images_dir: Directory containing the slide images.
            masks_dir: Directory containing the ``*_tissue.tif`` masks.
            csv_file: CSV with at least ``case_id`` and ``follow_up_years`` columns.
            num_patches_per_image: Number of patches sampled from each slide.
            level: Pyramid level passed to ``extract_patches``.
            size_mask: ``(width, height)`` of a mask tile passed to ``extract_patches``.
        """
        self.images_dir = images_dir
        self.masks_dir = masks_dir
        self.df = pd.read_csv(csv_file)
        self.num_patches_per_image = num_patches_per_image
        self.level = level
        self.size_mask = size_mask
        # NOTE(review): every directory entry is treated as a slide; presumably
        # images_dir holds only .tif files - confirm.
        self.image_list = sorted(os.listdir(images_dir))

    def __len__(self):
        """Return the number of slides (one item per entry of ``images_dir``)."""
        return len(self.image_list)

    def __getitem__(self, idx):
        """Load the patches and labels of slide ``idx``.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: float32 patches scaled to
            [0, 1] and laid out as ``(patches, channels, height, width)``,
            and float32 labels of shape ``(patches, 1)`` that repeat the
            slide's ``follow_up_years`` once per patch.

        Raises:
            KeyError: If the label CSV has no row for this slide.
        """
        image_file = self.image_list[idx]

        # Resolve the label before any slide I/O so a missing row fails fast.
        # A KeyError is raised on purpose: an IndexError escaping __getitem__
        # is read by Python's legacy iteration protocol as "end of sequence".
        matches = self.df.loc[self.df['case_id'] == image_file]
        if matches.empty:
            raise KeyError(f"no row with case_id == {image_file!r} in the labels CSV")
        years = matches.iloc[0]["follow_up_years"]

        impath = os.path.join(self.images_dir, image_file)
        mspath = os.path.join(self.masks_dir, image_file.replace('.tif', '_tissue.tif'))

        # Close both slide handles once the patches are read; the previous
        # version leaked two open handles per item.
        with openslide.OpenSlide(impath) as im_slide, openslide.OpenSlide(mspath) as ms_slide:
            image_patches = extract_patches(
                im_slide, ms_slide, self.level, self.size_mask, self.num_patches_per_image
            )

        image_patches = np.array(image_patches, dtype=np.float32) / 255.0  # Normalize to [0, 1]
        labels = np.full((len(image_patches), 1), years, dtype=np.float32)

        return torch.from_numpy(image_patches.transpose(0, 3, 1, 2)), torch.from_numpy(labels)


In [None]:
import torchvision.models as models

class ResNetBlock(nn.Module):
    """Residual block: two 3x3 conv + batch-norm stages plus a shortcut.

    The shortcut is an empty ``nn.Sequential`` (it returns its input
    unchanged) unless the stride or the channel count changes, in which case
    it is a 1x1 convolution followed by batch-norm so both branches have
    matching shapes before they are summed.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()

        # Main branch. Only the first convolution applies the stride.
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Shortcut branch: a projection only when the output shape differs.
        shape_changes = stride != 1 or in_channels != out_channels
        projection = []
        if shape_changes:
            projection = [
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            ]
        self.shortcut = nn.Sequential(*projection)

    def forward(self, x):
        """Apply the block to ``x`` and return the activated sum of both branches."""
        residual = self.shortcut(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += residual
        return self.relu(out)

class ResNetTransformer(nn.Module):
    """ResNet-style patch encoder followed by a transformer over patches.

    Each patch is encoded independently by the convolutional backbone into a
    512-dimensional vector. The patch vectors of one sample are then mixed by
    a 3-layer transformer encoder, and a small MLP head maps every patch token
    to a single scalar.

    Input:  ``(batch, patches, C, H, W)`` with ``(C, H, W) == input_shape``.
    Output: ``(batch, patches, 1)``.
    """

    def __init__(self, num_patches, input_shape):
        """
        Args:
            num_patches: Nominal number of patches per sample. Kept as an
                attribute for compatibility; ``forward`` infers the actual
                patch count from its input.
            input_shape: ``(channels, height, width)`` of a single patch.
        """
        super(ResNetTransformer, self).__init__()
        self.num_patches = num_patches
        self.input_shape = input_shape

        # ResNet backbone
        self.conv1 = nn.Conv2d(input_shape[0], 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(64, 64, 2)
        self.layer2 = self._make_layer(64, 128, 2, stride=2)
        self.layer3 = self._make_layer(128, 256, 2, stride=2)
        self.layer4 = self._make_layer(256, 512, 2, stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        # batch_first=True is required here: forward() feeds the encoder
        # (batch, patches, 512). With the default (sequence-first) layout the
        # encoder would treat the batch axis as the sequence and attend across
        # different samples instead of across the patches of one sample.
        # Parameter names and shapes are unaffected, so checkpoints still load.
        encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=3)

        # Final layers
        self.fc = nn.Linear(512, 64)
        self.output = nn.Linear(64, 1)  # one scalar per patch token

    def _make_layer(self, in_channels, out_channels, num_blocks, stride=1):
        """Stack ``num_blocks`` residual blocks; only the first one may stride or widen."""
        layers = [ResNetBlock(in_channels, out_channels, stride)]
        layers.extend(ResNetBlock(out_channels, out_channels) for _ in range(1, num_blocks))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Encode every patch, mix patch tokens per sample, and regress one value per patch."""
        batch_size = x.size(0)
        # Fold the patch axis into the batch axis so the CNN sees one patch per
        # row. reshape (not view) also accepts non-contiguous input.
        x = x.reshape(-1, self.input_shape[0], self.input_shape[1], self.input_shape[2])

        # ResNet backbone
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        # (batch * patches, 512, 1, 1) -> (batch, patches, 512). The patch count
        # is inferred, so inputs with a different number of patches also work.
        x = x.reshape(batch_size, -1, x.size(1))

        # Transformer Encoder (attends over the patches of each sample)
        x = self.transformer_encoder(x)

        # Final layers
        x = self.fc(x)
        x = self.relu(x)
        x = self.output(x)

        return x

In [None]:
import gc

In [None]:
def custom_loss(y_pred, y_true):
    """Mean squared error between channel 0 of the predictions and of the targets.

    Both tensors are indexed as ``[:, :, 0]``, so only the first entry along
    the third axis contributes to the loss.
    """
    predicted = y_pred[:, :, 0]
    expected = y_true[:, :, 0]
    return nn.functional.mse_loss(predicted, expected)

# TRAINING
def train_model(model, train_loader, optimizer, device, num_epochs, checkpoint_dir):
    """Train ``model`` with ``custom_loss`` and checkpoint it periodically.

    Args:
        model: Module to optimise; it is switched to training mode.
        train_loader: Sized iterable of ``(data, target)`` batches.
        optimizer: Optimizer built over ``model``'s parameters.
        device: Device each batch is moved to before the forward pass.
        num_epochs: Number of passes over ``train_loader``.
        checkpoint_dir: Directory that receives the ``.pth`` checkpoints; it is
            created if it does not exist.

    A checkpoint is written after every epoch whose zero-based index is a
    multiple of 10 (i.e. after epochs 1, 11, 21, ...), named
    ``epoch<NNNNN>-loss<avg>.pth`` with the same one-based epoch number and
    average loss that are printed.
    """
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)

    model.train()
    for epoch in range(num_epochs):
        total_loss = 0.0
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = custom_loss(output, target)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            # Drop the batch tensors eagerly to keep peak memory down.
            del data, target, output, loss
            gc.collect()

        # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
        avg_loss = total_loss / max(len(train_loader), 1)
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")

        if epoch % 10 == 0:
            # The old name used `{epoch:5f}`, which rendered the epoch as a
            # float ("0.000000"), and the summed rather than the average loss.
            checkpoint_path = os.path.join(
                checkpoint_dir, f"epoch{epoch + 1:05d}-loss{avg_loss:.5f}.pth"
            )
            # NOTE(review): if `model` is wrapped in nn.DataParallel the saved
            # keys carry a "module." prefix - keep that in mind when loading.
            torch.save(model.state_dict(), checkpoint_path)
            print(f"Model saved at epoch {epoch + 1}")

In [None]:
# --- Run configuration ---------------------------------------------------
BATCH_SIZE = 7                 # slides per batch
NUM_PATCHES = 33               # patches sampled from each slide
INPUT_SHAPE = (3, 512, 512)    # (channels, height, width) of one patch
TOTAL_IMAGES = 55

# --- Kaggle input / output locations --------------------------------------
images_dir = "/kaggle/input/dddddddd/images"
masks_dir = "/kaggle/input/dddddddd/masks"
csv_file = "/kaggle/input/fuckthis/training_labels.csv"
checkpoint_dir = "/kaggle/working"

# --- Data -------------------------------------------------------------------
# NOTE(review): INPUT_SHAPE presumably has to equal the patch size produced by
# extract_patches for level=1 and 64x64 mask tiles - confirm.
dataset = PatchDataset(
    images_dir,
    masks_dir,
    csv_file,
    num_patches_per_image=NUM_PATCHES,
    level=1,
    size_mask=(64, 64),
)
train_loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
)

# --- Model and optimizer ----------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ResNetTransformer(num_patches=NUM_PATCHES, input_shape=INPUT_SHAPE)
model = nn.DataParallel(model).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [None]:
# Sanity check: pull a single batch and show the tensor shapes.
first_batch = next(iter(train_loader), None)
if first_batch is not None:
    x, y = first_batch
    print(x.shape , y.shape)
    

In [None]:
# Launch training; checkpoints are written to `checkpoint_dir`.
num_epochs = 1
train_model(
    model=model,
    train_loader=train_loader,
    optimizer=optimizer,
    device=device,
    num_epochs=num_epochs,
    checkpoint_dir=checkpoint_dir,
)

In [None]:
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# print(f"Using device: {device}")

# # Usage
# BATCH_SIZE = 16
# NUM_PATCHES = 4
# INPUT_SHAPE = (3, 512, 512)

# # Create the model
# model = ResNetTransformer(num_patches=NUM_PATCHES, input_shape=INPUT_SHAPE)

# if torch.cuda.device_count() > 1:
#     print(f"Using {torch.cuda.device_count()} GPUs")
#     model = nn.DataParallel(model)

# model = model.to(device)

# sample_input = torch.randn(BATCH_SIZE, NUM_PATCHES, *INPUT_SHAPE, device=device)

# output = model(sample_input)

# print(f"Input shape: {sample_input.shape}")
# print(f"Output shape: {output.shape}")

# **INFERENCE**

In [None]:
        # NOTE(review): incomplete fragment. These lines use `self`, `image_file`
        # and `return`, but this cell contains no enclosing `def`/`class` header
        # and never assigns `image_file`, so the cell cannot run as written.
        # The statements repeat the patch-loading steps of
        # PatchDataset.__getitem__ without the label lookup - presumably the body
        # of an inference-time `__getitem__`; TODO restore the missing header.

        # Build the slide path and the matching `_tissue.tif` mask path.
        impath = os.path.join(self.images_dir, image_file)
        mspath = os.path.join(self.masks_dir, image_file.replace('.tif', '_tissue.tif'))
        
        # NOTE(review): neither slide handle is closed after use.
        im_slide = openslide.OpenSlide(impath)
        ms_slide = openslide.OpenSlide(mspath)
        
        # Sample the patches and scale the pixel values by 1/255.
        image_patches = extract_patches(im_slide, ms_slide, self.level, self.size_mask, self.num_patches_per_image)
        image_patches = np.array(image_patches, dtype=np.float32) / 255.0  # Normalize to [0, 1]
                

        # Move axis 3 to position 1 before handing the array to torch.
        return torch.from_numpy(image_patches.transpose(0, 3, 1, 2))