In [6]:
import os
import random
import numpy as np
import pandas as pd
import nibabel as nib
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as T
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
import warnings
warnings.filterwarnings("ignore")



# ----------------------------
# ImageNet transforms
# ----------------------------
# Per-channel mean/std used to normalize inputs for torchvision's ImageNet-pretrained backbones.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Applied to one (H, W, 3) numpy frame at a time:
#   ToTensor  : H,W,C -> C,H,W
#   Resize    : to a fixed 224x224 spatial size
#   Normalize : standardize each channel with the ImageNet statistics above
# NOTE(review): ToTensor only rescales uint8 arrays to [0, 1]; the datasets feed float32 frames,
# which pass through unscaled — confirm the NIfTI intensities are already in [0, 1].
imagenet_transform = T.Compose(
    [
        T.ToTensor(),
        T.Resize((224, 224)),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)

# ----------------------------
# Shared sweep loading / frame sampling helpers
# ----------------------------
# Number of frames every sweep is resampled to before it is fed to the model.
TARGET_FRAMES = 16


def list_sweep_paths(row, sweep_cols):
    """Return the non-missing sweep paths of ``row``, in column order.

    Empty CSV cells come back from pandas as NaN; they are dropped here so they are
    never handed to ``nib.load``.
    """
    return [path for path in row[sweep_cols].tolist() if pd.notna(path)]


def load_sweep(path):
    """Load a NIfTI sweep as a float32 array.

    NOTE(review): the original code documents the layout as (W, H, C=1, frames) — confirm.
    """
    return nib.load(path).get_fdata().astype(np.float32)


def sample_frames(img, target_frames=TARGET_FRAMES):
    """Return exactly ``target_frames`` frames taken along the last axis of ``img``.

    Sweeps with at least ``target_frames`` frames are subsampled at evenly spaced
    indices (first and last frame always kept). Shorter sweeps are tiled along the
    frame axis and truncated to ``target_frames``.

    Raises:
        ValueError: if ``img`` has no frames (the repeat factor would be undefined).
    """
    num_frames = img.shape[-1]
    if num_frames == 0:
        raise ValueError("cannot sample frames from a sweep with zero frames")
    if num_frames >= target_frames:
        indices = np.linspace(0, num_frames - 1, target_frames, dtype=int)
        return img[..., indices]
    repeat_factor = int(np.ceil(target_frames / num_frames))
    # Tile only the last (frame) axis, whatever the rank of ``img``.
    reps = (1,) * (img.ndim - 1) + (repeat_factor,)
    return np.tile(img, reps)[..., :target_frames]


def sweep_frames_to_tensor(img, transform=None):
    """Stack the frames of a 4-D sweep (frame index last) into a (T, C, H, W) tensor.

    Each frame ``img[:, :, :, f]`` has its channel axis repeated 3x (3 channels when
    C == 1) and is then passed through ``transform``. Without a transform the frame is
    only moved to channels-first so the frames can still be stacked into a tensor.
    """
    frames = []
    for f in range(img.shape[-1]):
        frame = np.repeat(img[:, :, :, f], 3, axis=2)
        if transform is not None:
            frame = transform(frame)
        else:
            frame = torch.from_numpy(np.ascontiguousarray(frame.transpose(2, 0, 1)))
        frames.append(frame)
    return torch.stack(frames, dim=0)


# ----------------------------
# Training dataset: single sweep per sample
# ----------------------------
class SweepDataset(Dataset):
    """One randomly chosen sweep per study, resampled to ``target_frames`` frames.

    Reads every column whose name starts with ``path_nifti`` as a sweep path and the
    ``ga`` column as the regression target.

    Args:
        csv_path: CSV listing one study per row.
        transform: per-frame transform applied to (H, W, 3) numpy frames.
        load_nifti: if False, the path cell is passed through unchanged.
        target_frames: number of frames each sweep is resampled to.
    """

    def __init__(self, csv_path, transform=None, load_nifti=True, target_frames=TARGET_FRAMES):
        self.df = pd.read_csv(csv_path)
        self.transform = transform
        self.load_nifti = load_nifti
        self.target_frames = target_frames
        self.sweep_cols = [c for c in self.df.columns if c.startswith('path_nifti')]

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(frames (T, C, H, W), ga label)`` for study ``idx``.

        Raises:
            ValueError: if the row has no non-empty sweep path.
        """
        row = self.df.iloc[idx]
        paths = list_sweep_paths(row, self.sweep_cols)
        if not paths:
            raise ValueError(f"row {idx} has no sweep path in columns {self.sweep_cols}")
        # A different sweep may be drawn each time the same study is visited.
        path = random.choice(paths)

        if self.load_nifti:
            img = sample_frames(load_sweep(path), self.target_frames)
        else:
            # NOTE(review): passthrough kept from the original; with CSV input ``path`` is a
            # string, which the frame stacking below cannot handle.
            img = path

        frames = sweep_frames_to_tensor(img, self.transform)  # (T, C, H, W)
        label = torch.tensor(row['ga'], dtype=torch.float32)
        return frames, label


# ----------------------------
# Validation/Test dataset: multiple sweeps per sample
# ----------------------------
class SweepEvalDataset(Dataset):
    """Every sweep of a study (optionally only the first ``n_sweeps``), each resampled to
    ``target_frames`` frames and stacked.

    NOTE(review): with ``batch_size > 1`` the default collate requires every study in a
    batch to yield the same number of sweeps.

    Args:
        csv_path: CSV listing one study per row (needs a ``ga`` column).
        n_sweeps: keep only the first ``n_sweeps`` non-empty sweeps (falsy = keep all).
        transform: per-frame transform applied to (H, W, 3) numpy frames.
        load_nifti: if False, the path cell is passed through unchanged.
        target_frames: number of frames each sweep is resampled to.
    """

    def __init__(self, csv_path, n_sweeps=None, transform=None, load_nifti=True,
                 target_frames=TARGET_FRAMES):
        self.df = pd.read_csv(csv_path)
        self.transform = transform
        self.load_nifti = load_nifti
        self.sweep_cols = [c for c in self.df.columns if c.startswith('path_nifti')]
        self.n_sweeps = n_sweeps
        self.target_frames = target_frames

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(sweeps (num_sweeps, T, C, H, W), ga label)`` for study ``idx``.

        Raises:
            ValueError: if the row has no non-empty sweep path.
        """
        row = self.df.iloc[idx]
        paths = list_sweep_paths(row, self.sweep_cols)
        if self.n_sweeps:
            paths = paths[:self.n_sweeps]
        if not paths:
            raise ValueError(f"row {idx} has no sweep path in columns {self.sweep_cols}")

        all_sweeps = []
        for path in paths:
            if self.load_nifti:
                img = sample_frames(load_sweep(path), self.target_frames)
                frames = sweep_frames_to_tensor(img, self.transform)  # (T, C, H, W)
            else:
                # NOTE(review): passthrough kept from the original; assumes the cell already
                # holds a (T, C, H, W) tensor.
                frames = path
            all_sweeps.append(frames)

        all_sweeps = torch.stack(all_sweeps, dim=0)  # (num_sweeps, T, C, H, W)
        label = torch.tensor(row['ga'], dtype=torch.float32)
        return all_sweeps, label


# ----------------------------
# Model Definition
# ----------------------------
class WeightedAverageAttention(nn.Module):
    """Attention pooling over dim 1 of a (B, T, feature_dim) feature tensor.

    Each position is scored by a small tanh MLP (feature_dim -> hidden_dim -> 1). The scores
    are softmax-normalized over dim 1 and used to take a weighted sum of a linear
    projection (feature_dim -> reduced_dim) of the same features.

    Args:
        feature_dim: width of the incoming per-position features.
        reduced_dim: width of the pooled output.
        hidden_dim: width of the scoring MLP's hidden layer (previously fixed at 64).
    """

    def __init__(self, feature_dim=512, reduced_dim=128, hidden_dim=64):
        super().__init__()
        self.W = nn.Linear(feature_dim, hidden_dim)
        self.V = nn.Linear(hidden_dim, 1)
        self.Q = nn.Linear(feature_dim, reduced_dim)

    def forward(self, features):
        """Pool ``features`` (B, T, feature_dim).

        Returns:
            ``(pooled (B, reduced_dim), attention weights (B, T))``; the weights of each
            sample sum to 1 over T.
        """
        attn_scores = self.V(torch.tanh(self.W(features)))  # (B, T, 1)
        attn_weights = F.softmax(attn_scores, dim=1)        # (B, T, 1)
        reduced_features = self.Q(features)                 # (B, T, reduced_dim)
        weighted_sum = torch.sum(attn_weights * reduced_features, dim=1)
        return weighted_sum, attn_weights.squeeze(-1)


class NEJMbaseline(nn.Module):
    """ResNet-18 frame encoder + attention pooling over frames + linear regression head.

    Args:
        reduced_dim: width of the pooled representation fed to the regression head.
        fine_tune_backbone: if False, ResNet parameters get ``requires_grad = False``.
            NOTE(review): BatchNorm running statistics still update in ``train()`` mode.
        pretrained: load torchvision's ImageNet weights for ResNet-18.
    """

    def __init__(self, reduced_dim=128, fine_tune_backbone=True, pretrained=True):
        super().__init__()
        # `weights=` replaces the deprecated `pretrained=` flag; DEFAULT is the ImageNet-1k checkpoint.
        weights = models.ResNet18_Weights.DEFAULT if pretrained else None
        resnet = models.resnet18(weights=weights)
        # Keep every child except the final fc layer (ends with global average pooling).
        self.feature_extractor = nn.Sequential(*list(resnet.children())[:-1])
        # Width of the pooled features the dropped fc layer consumed (512 for ResNet-18).
        self.feature_dim = resnet.fc.in_features
        if not fine_tune_backbone:
            for param in self.feature_extractor.parameters():
                param.requires_grad = False
        self.attention = WeightedAverageAttention(feature_dim=self.feature_dim, reduced_dim=reduced_dim)
        self.fc = nn.Linear(reduced_dim, 1)

    def forward(self, x):
        """
        x: (B, T, C, H, W) — every frame is encoded independently, then pooled over T.

        Returns:
            ``(predictions (B, 1), attention weights (B, T))``.
        """
        B, T, C, H, W = x.shape
        # reshape (not view) also accepts non-contiguous inputs.
        features = self.feature_extractor(x.reshape(B * T, C, H, W))  # (B*T, feature_dim, 1, 1)
        features = features.reshape(B, T, self.feature_dim)
        aggregated, attn_weights = self.attention(features)
        output = self.fc(aggregated)
        return output, attn_weights


# ----------------------------
# Training + Validation
# ----------------------------
def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimization pass over ``loader``; return the mean per-sample loss."""
    model.train()
    total_loss = 0.0
    for frames, labels in loader:
        frames = frames.to(device)  # (B, T, C, H, W)
        labels = labels.float().to(device).unsqueeze(1)
        optimizer.zero_grad()
        outputs, _ = model(frames)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # criterion averages over the batch; re-weight so the epoch mean is per sample.
        total_loss += loss.item() * frames.size(0)
    return total_loss / len(loader.dataset)


def _evaluate_multi_sweep(model, loader, criterion, device):
    """Return the mean per-sample loss over ``loader`` without updating the model.

    All sweeps of a study are concatenated into one frame sequence, so the attention
    pools over every frame of every sweep at once.
    """
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for sweeps, labels in loader:
            sweeps = sweeps.to(device)  # (B, num_sweeps, T, C, H, W)
            B, S, T, C, H, W = sweeps.shape
            labels = labels.float().to(device).unsqueeze(1)
            outputs, _ = model(sweeps.reshape(B, S * T, C, H, W))
            total_loss += criterion(outputs, labels).item() * B
    return total_loss / len(loader.dataset)


def train_and_validate(train_csv, val_csv, epochs=100, batch_size=8, n_sweeps_val=8,
                       save_path='best_model.pth', lr=1e-4, num_workers=4, seed=None):
    """Train ``NEJMbaseline`` with an L1 (MAE) loss and keep the best checkpoint by validation MAE.

    Training samples one random sweep per study (``SweepDataset``). Validation uses up
    to ``n_sweeps_val`` sweeps per study (``SweepEvalDataset``). Whenever the validation
    MAE improves, a checkpoint dict with ``model_state_dict``, the pickled
    ``model_architecture``, ``epoch`` and ``val_loss`` is written to ``save_path``.

    Args:
        train_csv, val_csv: CSVs consumed by the sweep datasets.
        epochs: number of training epochs.
        batch_size: batch size for both loaders.
        n_sweeps_val: sweeps per validation study.
        save_path: where the best checkpoint is written (parent dirs are created).
        lr: Adam learning rate.
        num_workers: DataLoader worker processes.
        seed: if given, seeds ``random``, ``numpy`` and ``torch`` for a reproducible run.

    Returns:
        The best validation MAE observed (``inf`` if ``epochs`` is 0).
    """
    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    train_dataset = SweepDataset(train_csv, transform=imagenet_transform)
    val_dataset = SweepEvalDataset(val_csv, n_sweeps=n_sweeps_val, transform=imagenet_transform)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    model = NEJMbaseline().to(device)
    criterion = nn.L1Loss()  # MAE
    optimizer = Adam(model.parameters(), lr=lr)

    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    best_val_loss = float('inf')

    for epoch in range(1, epochs + 1):
        train_loss = _train_one_epoch(model, train_loader, criterion, optimizer, device)
        print(f"Epoch {epoch} | Train MAE: {train_loss:.4f}")

        val_loss = _evaluate_multi_sweep(model, val_loader, criterion, device)
        print(f"Epoch {epoch} | Val MAE: {val_loss:.4f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            print(f"✅ Saving new best model (Val MAE: {val_loss:.4f})")
            torch.save({
                'model_state_dict': model.state_dict(),
                # Whole pickled module: infer_test rebuilds the model from this entry.
                'model_architecture': model,
                'epoch': epoch,
                'val_loss': val_loss
            }, save_path)

    print(f"Training completed. Best Val MAE: {best_val_loss:.4f}")
    return best_val_loss



In [9]:
def infer_test(test_csv, model_path='best_model.pth', n_sweeps_test=8, output_csv='test_predictions.csv',
               batch_size=8, num_workers=4):
    """Predict ``ga`` for every study in ``test_csv`` and write ``study_id,ga`` to ``output_csv``.

    The model is rebuilt from the checkpoint written by ``train_and_validate``
    (``model_architecture`` + ``model_state_dict``). Up to ``n_sweeps_test`` sweeps per
    study are concatenated into one frame sequence, as in validation.

    NOTE: ``SweepEvalDataset`` reads ``row['ga']``, so ``test_csv`` needs a ``ga`` column
    (placeholder values are fine; labels are ignored here) as well as ``study_id``.

    Returns:
        The predictions DataFrame that was written to ``output_csv``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # SECURITY: weights_only=False unpickles arbitrary Python objects (the checkpoint stores the
    # whole module). Only load checkpoints produced by this notebook / a trusted source.
    checkpoint = torch.load(model_path, map_location=device, weights_only=False)
    model = checkpoint['model_architecture']
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()

    test_dataset = SweepEvalDataset(test_csv, n_sweeps=n_sweeps_test, transform=imagenet_transform)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    predictions = []
    with torch.no_grad():
        for sweeps, _ in test_loader:
            sweeps = sweeps.to(device)  # (B, num_sweeps, T, C, H, W)
            B, S, T, C, H, W = sweeps.shape
            outputs, _ = model(sweeps.reshape(B, S * T, C, H, W))
            predictions.extend(outputs.squeeze(1).cpu().numpy())

    # shuffle=False keeps predictions in CSV row order, so ids come straight from the dataset frame.
    study_ids = test_dataset.df['study_id'].tolist()

    result_df = pd.DataFrame({'study_id': study_ids, 'ga': predictions})
    result_df.to_csv(output_csv, index=False)
    print(f"✅ Saved predictions to {output_csv}")
    return result_df

In [10]:
# Inference on the held-out test split.
# NOTE(review): this cell sits above the training cell and loads 'best_model.pth'; on a fresh
# kernel (Restart & Run All) it only works if that checkpoint already exists on disk.
test_csv = '/mnt/Data/hackathon/data_splits/final_test.csv'
infer_test(
    test_csv,
    model_path='best_model.pth',
    n_sweeps_test=8,
    output_csv='predictions.csv',
)


✅ Saved predictions to predictions.csv


In [None]:
# ----------------------------
# Training run: fits on the train split and writes the checkpoint with the best
# validation MAE to 'best_model.pth' (the path the inference cell loads).
# ----------------------------
train_csv = '/mnt/Data/hackathon/data_splits/final_train.csv'
val_csv = '/mnt/Data/hackathon/data_splits/final_valid.csv'

train_and_validate(
    train_csv,
    val_csv,
    epochs=100,
    batch_size=16,
    n_sweeps_val=8,
    save_path='best_model.pth',
)


In [12]:
# Inspect the file the inference cell just wrote (same relative path as its output_csv).
pd.read_csv('predictions.csv').head(10)

Unnamed: 0,study_id,ga
0,KA-AC-190-1,59.716717
1,KA-AC-111-1,80.689354
2,KA-AC-107-1,65.882286
3,KA-AC-76-1,81.253654
4,KA-AD-68-1,73.34454
5,KA-AD-204-1,65.04141
6,KA-AC-213-1,77.429756
7,KA-AD-193-1,63.790897
8,KA-AD-134-1,73.665955
9,KA-AC-84-1,47.635925
