In [1]:
!python -V
!pip -V
!python -c "import sys, pkgutil; print('numpy', pkgutil.find_loader('numpy') is not None); print('torch', pkgutil.find_loader('torch') is not None)"
!pip install --upgrade --no-deps timm pylibjpeg pylibjpeg-libjpeg pylibjpeg-openjpeg
!pip install --upgrade --no-deps pylibjpeg==2.1.0 pylibjpeg-libjpeg==2.3.0 pylibjpeg-openjpeg==2.5.0 || true
!pip check || true

Python 3.11.13
pip 24.1.2 from /usr/local/lib/python3.11/dist-packages/pip (python 3.11)
numpy True
torch True
Collecting timm
  Downloading timm-1.0.20-py3-none-any.whl.metadata (61 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.7/61.7 kB[0m [31m1.3 MB/s[0m eta [36m0:00:00[0m00:01[0m
[?25hCollecting pylibjpeg
  Downloading pylibjpeg-2.1.0-py3-none-any.whl.metadata (7.9 kB)
Collecting pylibjpeg-libjpeg
  Downloading pylibjpeg_libjpeg-2.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.8 kB)
Collecting pylibjpeg-openjpeg
  Downloading pylibjpeg_openjpeg-2.5.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (5.8 kB)
Downloading timm-1.0.20-py3-none-any.whl (2.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.5/2.5 MB[0m [31m22.8 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hDownloading pylibjpeg-2.1.0-py3-none-any.whl (25 kB)
Downloading pylibjpeg_libjpeg-2.3.0-cp311-cp311-manylinux_

In [None]:
import os, sys
os.kill(os.getpid(), 9)

In [20]:
import os
import random
import numpy as np
import pandas as pd
import json
from PIL import Image
import pydicom
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
import timm
from tqdm import tqdm
import math

In [16]:
DATA_PATH = "/kaggle/input/rsna-2022-cervical-spine-fracture-detection"
train_df = pd.read_csv(os.path.join(DATA_PATH, "train.csv"))
test_df = pd.read_csv(os.path.join(DATA_PATH, "test.csv"))
sample_sub = pd.read_csv(os.path.join(DATA_PATH, "sample_submission.csv"))
TRAIN_IMG_DIR = os.path.join(DATA_PATH, "train_images")

study_ids = train_df["StudyInstanceUID"].unique()
np.random.seed(33)
np.random.shuffle(study_ids)
split_idx = int(len(study_ids) * 0.8)
train_studies = study_ids[:split_idx]
val_studies = study_ids[split_idx:]
train_df_split = train_df[train_df["StudyInstanceUID"].isin(train_studies)]
val_df_split = train_df[train_df["StudyInstanceUID"].isin(val_studies)]

## ViT + Cervical Dataset Classes

In [17]:
class CervicalSliceDataset(Dataset):
    def __init__(self, df, root, transform=None, num_slices=5):
        self.df = df
        self.root = root
        self.transform = transform
        self.num_slices = num_slices
        self.study_ids = df["StudyInstanceUID"].unique().tolist()
    
    def __len__(self):
        return len(self.study_ids)
    
    def __getitem__(self, idx):
        study = self.study_ids[idx]
        folder = os.path.join(self.root, study)
        files = sorted([f for f in os.listdir(folder) if f.endswith(".dcm")])
        
        if len(files) == 0:
            raise RuntimeError(f"No DICOM in {folder}")
        
        indices = np.linspace(0, len(files)-1, self.num_slices, dtype=int)
        slices = []
        
        for i in indices:
            path = os.path.join(folder, files[i])
            ds = pydicom.dcmread(path)
            try:
                arr = ds.pixel_array
            except Exception:
                ds.decompress()
                arr = ds.pixel_array
            
            if arr.ndim == 3:
                arr = arr[0]
            
            img = Image.fromarray(arr).convert("L")
            if self.transform:
                img = self.transform(img)
            slices.append(img)
        
        img_tensor = torch.mean(torch.stack(slices), dim=0)
        
        row = self.df[self.df["StudyInstanceUID"]==study].iloc[0]
        labels = torch.zeros(8, dtype=torch.float32)
        labels[0] = row["patient_overall"]
        for i in range(1,8):
            labels[i] = row[f"C{i}"]
        
        return img_tensor, labels

class PatchEmbedding(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2
        
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
    
    def forward(self, x):
        x = self.proj(x)
        x = x.flatten(2)
        x = x.transpose(1, 2)
        return x

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim=768, num_heads=12, dropout=0.1):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        
        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        
        attn = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5)
        attn = attn.softmax(dim=-1)
        attn = self.dropout(attn)
        
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.dropout(x)
        return x

class MLP(nn.Module):
    def __init__(self, embed_dim=768, hidden_dim=3072, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(embed_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x):
        x = self.fc1(x)
        x = F.gelu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x

class TransformerBlock(nn.Module):
    def __init__(self, embed_dim=768, num_heads=12, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadAttention(embed_dim, num_heads, dropout)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = MLP(embed_dim, int(embed_dim * mlp_ratio), dropout)
    
    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

class VisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3, num_classes=8,
                 embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        num_patches = self.patch_embed.n_patches
        
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.dropout = nn.Dropout(dropout)
        
        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout)
            for _ in range(depth)
        ])
        
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)
        
        self._init_weights()
    
    def _init_weights(self):
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.LayerNorm):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
    
    def forward(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)
        
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        x = self.dropout(x)
        
        for block in self.blocks:
            x = block(x)
        
        x = self.norm(x)
        cls_output = x[:, 0]
        x = self.head(cls_output)
        return x

train_transforms = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.RandomRotation(15),
    transforms.RandomHorizontalFlip(0.3),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
])

val_transforms = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
])

train_ds = CervicalSliceDataset(train_df_split, TRAIN_IMG_DIR, transform=train_transforms, num_slices=5)
val_ds = CervicalSliceDataset(val_df_split, TRAIN_IMG_DIR, transform=val_transforms, num_slices=5)

train_loader = DataLoader(train_ds, batch_size=16, shuffle=True, num_workers=0, pin_memory=True)
val_loader = DataLoader(val_ds, batch_size=16, shuffle=False, num_workers=0, pin_memory=True)

## Modelo ViT

In [18]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = VisionTransformer(
    img_size=224,
    patch_size=16,
    in_channels=3,
    num_classes=8,
    embed_dim=768,
    depth=12,
    num_heads=12,
    mlp_ratio=4.0,
    dropout=0.1
)
model = model.to(device)

pos_counts = train_df_split.iloc[:, 1:9].sum()
neg_counts = len(train_df_split) - pos_counts
pos_weight = (neg_counts / pos_counts).values
pos_weight = torch.tensor(pos_weight, dtype=torch.float32).to(device)

criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.05)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=15, eta_min=1e-6)

In [None]:
scaler = torch.amp.GradScaler('cuda') if torch.cuda.is_available() else None

def train_epoch(model, loader, optimizer, criterion, device, scaler=None):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    for imgs, labels in tqdm(loader, desc="Training"):
        imgs = imgs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        optimizer.zero_grad()
        
        if scaler is not None:
            with torch.amp.autocast('cuda'):
                out = model(imgs)
                loss = criterion(out, labels)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            out = model(imgs)
            loss = criterion(out, labels)
            loss.backward()
            optimizer.step()
        
        running_loss += loss.item() * imgs.size(0)
        preds = (torch.sigmoid(out) > 0.5).float()
        correct += (preds == labels).sum().item()
        total += labels.numel()

        del imgs, labels, out, loss
        if device.type == 'cuda':
            torch.cuda.empty_cache()
    
    return running_loss / len(loader.dataset), correct / total

def validate(model, loader, criterion, device):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for imgs, labels in tqdm(loader, desc="Validation"):
            imgs = imgs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            
            if torch.cuda.is_available():
                with torch.amp.autocast('cuda'):
                    out = model(imgs)
                    loss = criterion(out, labels)
            else:
                out = model(imgs)
                loss = criterion(out, labels)
            
            running_loss += loss.item() * imgs.size(0)
            preds = (torch.sigmoid(out) > 0.5).float()
            correct += (preds == labels).sum().item()
            total += labels.numel()
    
    return running_loss / len(loader.dataset), correct / total

results = []
best_val_acc = 0.0

for epoch in range(1, 16):
    train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, device, scaler)
    val_loss, val_acc = validate(model, val_loader, criterion, device)
    
    scheduler.step()
    
    results.append({
        'epoch': epoch,
        'train_loss': train_loss,
        'train_accuracy': train_acc,
        'val_loss': val_loss,
        'val_accuracy': val_acc,
        'lr': optimizer.param_groups[0]['lr']
    })
    
    print(f"Epoch {epoch}: train_loss={train_loss:.4f}, train_acc={train_acc:.4f}, "
          f"val_loss={val_loss:.4f}, val_acc={val_acc:.4f}, lr={optimizer.param_groups[0]['lr']:.2e}")
    
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), 'best_vit_custom.pth')
        print(f"  -> Saved best model with val_acc={val_acc:.4f}")

results_df = pd.DataFrame(results)
results_df.to_csv('training_results_custom_vit.csv', index=False)
print("\nResultados guardados en 'training_results_custom_vit.csv'")
print(f"\nMejor validación accuracy: {best_val_acc:.4f}")

Training: 100%|██████████| 101/101 [00:48<00:00,  2.07it/s]
Validation: 100%|██████████| 26/26 [00:09<00:00,  2.74it/s]


Epoch 1: train_loss=1.2418, train_acc=0.4803, val_loss=1.2577, val_acc=0.2840, lr=2.97e-04
  -> Saved best model with val_acc=0.2840


Training:  82%|████████▏ | 83/101 [00:40<00:08,  2.19it/s]