In [None]:
import os
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import models, transforms
import pandas as pd
import timm
import torch.nn.functional as F

# If using your custom data loader, import it (or define your own here)
from data import build_split_dataloaders  # ...existing code...
from model import SwinTransformerClassificationModel  # New import for SwinTransformer

# Hyperparameters
batch_size = 8
# NOTE(review): 1e-3 is aggressive for fine-tuning a pretrained Swin with Adam;
# in the recorded first-epoch output the training loss hovers around 0.9-1.0.
# Consider 1e-4 or lower.
learning_rate = 0.001
num_epochs = 10
log_dir = "runs/experiment_2"
seed = 42

# Seed torch's global RNG so DataLoader shuffling and any torch-based random
# split are repeatable. If build_split_dataloaders draws from another RNG
# (random, numpy, sklearn), seed that too — TODO confirm in data.py.
torch.manual_seed(seed)

# Data paths. The dataset location can be overridden with the RSNA_DATA_DIR
# environment variable. The separator after the drive letter matters: on
# Windows os.path.join("K:", "x") gives the drive-relative path "K:x"
# (resolved against drive K's current directory), not "K:\x".
dataset_dir = os.environ.get(
    "RSNA_DATA_DIR", os.path.join("K:" + os.sep, "rsna-breast-cancer-detection")
)
csv_path = os.path.join(dataset_dir, "train.csv")
root_dir = os.path.join(dataset_dir, "train_images_cropped")

# Data transformations: convert to a tensor, then resize to the 224x224 input
# size of 'swin_base_patch4_window7_224'. antialias=True pins the tensor-resize
# filtering (its default has changed across torchvision releases) so
# downsampling does not alias.
# NOTE(review): no Normalize step is applied although the backbone is
# pretrained — consider adding one that matches the pretrained config.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((224, 224), antialias=True),
])

# Build train/val/test dataloaders (build_split_dataloaders is defined in data.py).
train_loader, val_loader, test_loader = build_split_dataloaders(
    csv_path, root_dir, batch_size=batch_size, transform=transform, train=True, val_ratio=0.2, test_ratio=0.1
)

print(f"Train batches: {len(train_loader)}, Validation batches: {len(val_loader)}")

Train batches: 2301, Validation batches: 658


In [None]:
# Initialize model, criterion, optimizer, and TensorBoard writer

# Use SwinTransformer instead of resnet18
num_classes = 3

model = timm.create_model(
    'swin_base_patch4_window7_224', pretrained=True, num_classes=num_classes, global_pool='avg'
)

# The pretrained patch embedding takes 3-channel input, while this pipeline
# feeds single-channel images, so swap in a 1-channel conv with the same
# geometry. Instead of leaving it randomly initialised (which discards the
# pretrained first layer), seed it with the pretrained kernels summed over the
# input-channel axis: for a grayscale image replicated to RGB this reproduces
# the pretrained response exactly.
old_proj = model.patch_embed.proj
new_proj = nn.Conv2d(
    in_channels=1,
    out_channels=old_proj.out_channels,
    kernel_size=old_proj.kernel_size,
    stride=old_proj.stride,
    padding=old_proj.padding,
    bias=old_proj.bias is not None,
)
with torch.no_grad():
    new_proj.weight.copy_(old_proj.weight.sum(dim=1, keepdim=True))
    if old_proj.bias is not None:
        new_proj.bias.copy_(old_proj.bias)
model.patch_embed.proj = new_proj

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

criterion = nn.CrossEntropyLoss()
# Built after the conv swap so the new patch-embedding parameters are optimised.
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

writer = SummaryWriter(log_dir=log_dir)

print(f"Using device: {device}")

Using device: cuda


In [None]:
# Training loop with extra logging
log_interval = 10              # print / log training metrics every N batches
checkpoint_path = "model.pth"


def _evaluate(model, loader, criterion, device):
    """Return (mean per-sample loss, accuracy) of ``model`` over ``loader``.

    Each batch loss is weighted by its batch size, so a smaller final batch
    does not skew the average (criterion is nn.CrossEntropyLoss with its
    default 'mean' reduction). An empty loader yields (0.0, 0.0) rather than
    a ZeroDivisionError.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.to(device)
            labels = labels.to(device).long()
            outputs = model(inputs)
            batch_size_now = labels.size(0)
            total_loss += criterion(outputs, labels).item() * batch_size_now
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += batch_size_now
    if total == 0:
        return 0.0, 0.0
    return total_loss / total, correct / total


global_step = 0
try:
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        correct_train = 0
        total_train = 0
        epoch_start = time.time()
        for i, (inputs, labels) in enumerate(train_loader):
            inputs = inputs.to(device)
            labels = labels.to(device).long()

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            preds = outputs.argmax(dim=1)
            correct_train += (preds == labels).sum().item()
            total_train += labels.size(0)
            global_step += 1

            if (i + 1) % log_interval == 0:
                # Metrics cover only the last `log_interval` batches.
                avg_loss = running_loss / log_interval
                train_acc = correct_train / total_train
                print(f"[Epoch {epoch+1}, Batch {i+1}] loss: {avg_loss:.3f}  accuracy: {train_acc:.3f}")
                writer.add_scalar('training loss', avg_loss, global_step)
                writer.add_scalar('training accuracy', train_acc, global_step)
                running_loss = 0.0
                correct_train = 0
                total_train = 0

        epoch_time = time.time() - epoch_start
        print(f"Epoch {epoch+1} completed in {epoch_time:.2f} seconds")

        # Validation, logged on the same global_step axis as training so the
        # curves can be compared directly in TensorBoard.
        val_loss_avg, val_acc = _evaluate(model, val_loader, criterion, device)
        print(f"Validation loss after epoch {epoch+1}: {val_loss_avg:.3f}  accuracy: {val_acc:.3f}")
        writer.add_scalar('validation loss', val_loss_avg, global_step)
        writer.add_scalar('validation accuracy', val_acc, global_step)

        # Checkpoint after every epoch so an interrupted run keeps its progress.
        torch.save(model.state_dict(), checkpoint_path)
        print(f"Model saved to {checkpoint_path}")
finally:
    # Flush and close TensorBoard event files even if training is interrupted.
    writer.close()

[Epoch 1, Batch 10] loss: 2.028  accuracy: 0.537
[Epoch 1, Batch 20] loss: 1.118  accuracy: 0.512
[Epoch 1, Batch 30] loss: 0.971  accuracy: 0.512
[Epoch 1, Batch 40] loss: 0.811  accuracy: 0.675
[Epoch 1, Batch 50] loss: 0.969  accuracy: 0.463
[Epoch 1, Batch 60] loss: 0.947  accuracy: 0.625
[Epoch 1, Batch 70] loss: 1.000  accuracy: 0.450
[Epoch 1, Batch 80] loss: 0.907  accuracy: 0.625
[Epoch 1, Batch 90] loss: 0.977  accuracy: 0.487
[Epoch 1, Batch 100] loss: 1.031  accuracy: 0.562
[Epoch 1, Batch 110] loss: 0.919  accuracy: 0.550
[Epoch 1, Batch 120] loss: 0.927  accuracy: 0.625
[Epoch 1, Batch 130] loss: 1.022  accuracy: 0.287
[Epoch 1, Batch 140] loss: 0.909  accuracy: 0.600
[Epoch 1, Batch 150] loss: 0.990  accuracy: 0.525
[Epoch 1, Batch 160] loss: 0.941  accuracy: 0.537
[Epoch 1, Batch 170] loss: 0.951  accuracy: 0.575
[Epoch 1, Batch 180] loss: 0.955  accuracy: 0.562
[Epoch 1, Batch 190] loss: 0.944  accuracy: 0.562
[Epoch 1, Batch 200] loss: 0.922  accuracy: 0.537
[Epoch 1,

KeyboardInterrupt: 