In [1]:
import os
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'

In [2]:
import warnings

warnings.filterwarnings("ignore", category=UserWarning)

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
import os
import pydicom
from torchvision import transforms
from collections import defaultdict
from torch.cuda.amp import autocast, GradScaler
from torch.optim.lr_scheduler import CosineAnnealingLR
import timm

class MRIDataset(Dataset):
    def __init__(self, csv_file, root_dir, transform=None):
        self.data = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform
        
        self.bags = defaultdict(list)
        for _, row in self.data.iterrows():
            self.bags[row['SeriesInstanceUID']].append(row)
    
    def __len__(self):
        return len(self.bags)
    
    def __getitem__(self, idx):
        bag_id = list(self.bags.keys())[idx]
        bag = self.bags[bag_id]
        
        images = []
        for item in bag:
            img_path = os.path.join(self.root_dir, item['SeriesInstanceUID'])
            for img_file in os.listdir(img_path):
                if img_file.endswith('.dcm'):
                    dcm = pydicom.dcmread(os.path.join(img_path, img_file))
                    img = dcm.pixel_array.astype(np.float32)
                    img = (img - img.min()) / (img.max() - img.min())
                    if self.transform:
                        img = self.transform(img)
                    images.append(img)
        
        label = bag[0]['prediction']
        body_part = bag[0]['Body Part Examined']
        series_desc = bag[0]['Series Description']
        
        return torch.stack(images), label, body_part, series_desc

class GatedAttention(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(GatedAttention, self).__init__()
        self.attention_V = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.Tanh()
        )
        self.attention_U = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.Sigmoid()
        )
        self.attention_weights = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        V = self.attention_V(x)
        U = self.attention_U(x)
        attention = self.attention_weights(V * U)
        return attention

class TransformerMIL(nn.Module):
    def __init__(self, num_classes=2):
        super(TransformerMIL, self).__init__()
        
        self.feature_extractor = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=0, in_chans=1)
        self.feature_dim = self.feature_extractor.num_features
        
        self.attention = GatedAttention(self.feature_dim, 256)
        
        self.classifier = nn.Sequential(
            nn.Linear(self.feature_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes)
        )
        
    def forward(self, x):
        x = x.squeeze(0)
        H = self.feature_extractor(x)
        
        A = self.attention(H)
        A = torch.transpose(A, 1, 0)
        A = nn.functional.softmax(A, dim=1)
        
        M = torch.mm(A, H)
        Y_prob = self.classifier(M)
        Y_hat = torch.argmax(Y_prob, dim=1)
        
        return Y_prob, Y_hat, A

class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs, targets):
        BCE_loss = nn.functional.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        pt = torch.exp(-BCE_loss)
        F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss
        return F_loss.mean()

def train(model, train_loader, criterion, optimizer, device, scaler):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    for batch_idx, (data, targets, _, _) in enumerate(train_loader):
        data, targets = data.to(device), targets.to(device)
        
        with autocast():
            outputs, _, _ = model(data)
            loss = criterion(outputs, targets.float().unsqueeze(1))
        
        scaler.scale(loss).backward()
        
        if (batch_idx + 1) % 4 == 0:  # Gradient accumulation
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
        
        running_loss += loss.item()
        predicted = (outputs > 0.5).float()
        total += targets.size(0)
        correct += (predicted == targets.unsqueeze(1)).sum().item()
    
    accuracy = 100. * correct / total
    return running_loss / len(train_loader), accuracy

def evaluate(model, test_loader, criterion, device):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for batch_idx, (data, targets, _, _) in enumerate(test_loader):
            data, targets = data.to(device), targets.to(device)
            
            with autocast():
                outputs, _, _ = model(data)
                loss = criterion(outputs, targets.float().unsqueeze(1))
            
            running_loss += loss.item()
            predicted = (outputs > 0.5).float()
            total += targets.size(0)
            correct += (predicted == targets.unsqueeze(1)).sum().item()
    
    accuracy = 100. * correct / total
    return running_loss / len(test_loader), accuracy

batch_size = 1
num_epochs = 50
learning_rate = 0.0001
csv_file = '/kaggle/input/editedcsv/FinalMetaData.csv'
root_dir = '/kaggle/input/iaaa-mri-challenge/data/'

transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize((224, 224)),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

dataset = MRIDataset(csv_file, root_dir, transform=transform)
train_set, test_set = train_test_split(dataset, test_size=0.2, random_state=42)
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=4)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=4)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TransformerMIL().to(device)

criterion = FocalLoss()
optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-5)
scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs)
scaler = GradScaler()

for epoch in range(num_epochs):
    train_loss, train_acc = train(model, train_loader, criterion, optimizer, device, scaler)
    test_loss, test_acc = evaluate(model, test_loader, criterion, device)
    scheduler.step()
    
    print(f'Epoch {epoch+1}/{num_epochs}:')
    print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
    print(f'Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%')
    print('-----------------------------')

torch.save(model.state_dict(), 'transformer_mil_model.pth')

US1_J2KR.dcm:   0%|          | 38.0/154k [00:00<01:18, 1.95kB/s]
MR-SIEMENS-DICOM-WithOverlays.dcm:   0%|          | 125/511k [00:00<02:16, 3.74kB/s]
OBXXXX1A.dcm:   0%|          | 119/486k [00:00<01:45, 4.62kB/s]
US1_UNCR.dcm:   0%|          | 226/923k [00:00<02:23, 6.44kB/s]
color3d_jpeg_baseline.dcm:   0%|          | 1.50k/6.14M [00:00<06:23, 16.0kB/s]


model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

  scaler = GradScaler()
  with autocast():


ValueError: Target size (torch.Size([1, 1])) must be the same as input size (torch.Size([1, 2]))