In [5]:
import os
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import models, transforms
import pandas as pd
import timm
import torch.nn.functional as F

# If using your custom data loader, import it (or define your own here)
from data import build_split_dataloaders  # ...existing code...
from model import SwinTransformerClassificationModel  # New import for SwinTransformer

# Hyperparameters
batch_size = 8
learning_rate = 0.001
num_epochs = 10
experiment = "swin_attention_weighted"
log_dir = f"runs/{experiment}"

# Data paths
root_dir = os.path.join("K:", "rsna-breast-cancer-detection")
csv_path = os.path.join(root_dir, "train.csv")
root_dir = os.path.join(root_dir, "train_images_cropped")

# Data transformations
transform = transforms.Compose([
    transforms.ToTensor(),
    # RESIZE TO 224x224
    transforms.Resize((224, 224)),
])

# Build dataloaders (Assumes build_split_dataloaders is defined in data.py)
train_loader, val_loader, test_loader = build_split_dataloaders(
    csv_path, root_dir, batch_size=batch_size, transform=transform, train=True, val_ratio=0.2, test_ratio=0.1, paired=True
)

print(f"Train batches: {len(train_loader)}, Validation batches: {len(val_loader)}")

Train batches: 732, Validation batches: 209


In [6]:
from data import MedicalImageDataset

# Create a new dataset to check the class distribution
dataset = MedicalImageDataset(csv_path, root_dir, transform=transform, paired=True)
dataframe = dataset.metadata
cc_birads = dataframe["cc_birads"].values
mlo_birads = dataframe["mlo_birads"].values

birads_counts = cc_birads.tolist() + mlo_birads.tolist()
birads_counts = pd.Series(birads_counts).value_counts()

print(birads_counts)

# Create weights for loss function based on class distribution
birads_weights = 1 / birads_counts
birads_weights = birads_weights / birads_weights.sum()
birads_weights = birads_weights.sort_index()

birads_weights = torch.tensor(birads_weights.values).float()
birads_weights = birads_weights.to("cuda")

print(birads_weights)

0.0    8756
1.0    6266
2.0    1692
Name: count, dtype: int64
tensor([0.1321, 0.1845, 0.6834], device='cuda:0')


In [None]:
from model import SwinTransformerClassificationModel
# Initialize model, criterion, optimizer, and TensorBoard writer

# Use SwinTransformer instead of resnet18
num_classes = 3

model = SwinTransformerClassificationModel(num_classes=num_classes)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

writer = SummaryWriter(log_dir=log_dir)

print(f"Using device: {device}")

Using device: cuda


In [7]:
from model import SwinMammoClassifier

# Initialize model, criterion, optimizer, and TensorBoard writer

num_classes = 3

model = SwinMammoClassifier(num_classes=num_classes)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

criterion = nn.CrossEntropyLoss(reduction="none")
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

writer = SummaryWriter(log_dir=log_dir)

print(f"Using device: {device}")

Using device: cuda


In [None]:
import numpy as np

# Training loop with extra logging
global_step = 0
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    correct_train_batch = 0
    correct_total = 0
    total_train_batch = 0
    total_train = 0
    total_loss = 0.0
    epoch_start = time.time()
    
    for i, data in enumerate(train_loader):
        inputs, labels = data
        
        mlo, cc = inputs
        mlo_labels, cc_labels = labels
        mlo = mlo.to(device)
        cc = cc.to(device)
        mlo_labels = mlo_labels.to(device).long()
        cc_labels = cc_labels.to(device).long()
        
        optimizer.zero_grad()
        mlo_outputs, cc_outputs = model(mlo, cc)
        
        # Ensure the batch sizes match
        mlo_outputs = mlo_outputs.view(-1, num_classes)
        cc_outputs = cc_outputs.view(-1, num_classes)
        mlo_labels = mlo_labels.view(-1)
        cc_labels = cc_labels.view(-1)
        
        loss = criterion(mlo_outputs, mlo_labels) + criterion(cc_outputs, cc_labels)
        
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        # Compute predictions
        _, mlo_preds = torch.max(mlo_outputs, 1)
        _, cc_preds = torch.max(cc_outputs, 1)
        
        correct_train_batch += torch.sum(mlo_preds == mlo_labels).item() + torch.sum(cc_preds == cc_labels).item()
        
        
        total_train_batch += mlo_labels.size(0) + cc_labels.size(0)
        global_step += 1
        
        # _, preds = torch.max(outputs, 1)
        # correct_train += torch.sum(preds == labels).item()
        # total_train += labels.size(0)
        # global_step += 1
        
        if i % 10 == 9:
            avg_loss = running_loss / 10
            
            total_train += total_train_batch
            correct_total += correct_train_batch
            total_loss += running_loss
            
            avg_total_loss = total_loss / (i+1)
            
            train_acc_batch = correct_train_batch / total_train_batch
            print(f"[Epoch {epoch+1}, Batch {i+1}] loss: {avg_loss:.3f}  accuracy: {train_acc_batch:.3f} total accuracy: {correct_total / total_train:.3f} avg loss: {avg_total_loss:.3f}")
            writer.add_scalar('training loss', avg_loss, global_step)
            writer.add_scalar('training accuracy', train_acc_batch, global_step)
            
            writer.add_scalar('total training accuracy', correct_total / total_train, global_step)
            writer.add_scalar('total training loss', avg_total_loss, global_step)
            
            running_loss = 0.0
            correct_train_batch = 0
            total_train_batch = 0
    
    epoch_time = time.time() - epoch_start
    print(f"Epoch {epoch+1} completed in {epoch_time:.2f} seconds")
    
    # Validation loop with accuracy logging
    model.eval()
    val_loss = 0.0
    correct_val = 0
    total_val = 0
    with torch.no_grad():
        for data in val_loader:
            inputs, labels = data
            mlo, cc = inputs
            mlo_labels, cc_labels = labels
            mlo = mlo.to(device)
            cc = cc.to(device)
            mlo_labels = mlo_labels.to(device).long()
            cc_labels = cc_labels.to(device).long()
            
            mlo_outputs, cc_outputs = model(mlo, cc)
            
            # Ensure the batch sizes match
            mlo_outputs = mlo_outputs.view(-1, num_classes)
            cc_outputs = cc_outputs.view(-1, num_classes)
            mlo_labels = mlo_labels.view(-1)
            cc_labels = cc_labels.view(-1)
            
            loss = criterion(mlo_outputs, mlo_labels) + criterion(cc_outputs, cc_labels)
            val_loss += loss.item()
            
            _, mlo_preds = torch.max(mlo_outputs, 1)
            _, cc_preds = torch.max(cc_outputs, 1)
            correct_val += torch.sum(mlo_preds == mlo_labels).item() + torch.sum(cc_preds == cc_labels).item()
            total_val += mlo_labels.size(0) + cc_labels.size(0)
    val_loss_avg = val_loss / len(val_loader)
    val_acc = correct_val / total_val
    print(f"Validation loss after epoch {epoch+1}: {val_loss_avg:.3f}  accuracy: {val_acc:.3f}")
    writer.add_scalar('validation loss', val_loss_avg, epoch)
    writer.add_scalar('validation accuracy', val_acc, epoch)

# Save the model
torch.save(model.state_dict(), f"model_{experiment}.pth")
print("Model saved to model.pth")
writer.close()

[Epoch 1, Batch 10] loss: 0.480  accuracy: 0.106 total accuracy: 0.106 avg loss: 0.480
[Epoch 1, Batch 20] loss: 0.426  accuracy: 0.412 total accuracy: 0.259 avg loss: 0.453
[Epoch 1, Batch 30] loss: 0.480  accuracy: 0.281 total accuracy: 0.267 avg loss: 0.462
[Epoch 1, Batch 40] loss: 0.467  accuracy: 0.406 total accuracy: 0.302 avg loss: 0.463
[Epoch 1, Batch 50] loss: 0.432  accuracy: 0.487 total accuracy: 0.339 avg loss: 0.457
[Epoch 1, Batch 60] loss: 0.442  accuracy: 0.412 total accuracy: 0.351 avg loss: 0.455
[Epoch 1, Batch 70] loss: 0.516  accuracy: 0.325 total accuracy: 0.347 avg loss: 0.464
[Epoch 1, Batch 80] loss: 0.520  accuracy: 0.163 total accuracy: 0.324 avg loss: 0.471
[Epoch 1, Batch 90] loss: 0.456  accuracy: 0.188 total accuracy: 0.309 avg loss: 0.469
[Epoch 1, Batch 100] loss: 0.375  accuracy: 0.450 total accuracy: 0.323 avg loss: 0.460
[Epoch 1, Batch 110] loss: 0.478  accuracy: 0.325 total accuracy: 0.323 avg loss: 0.461
[Epoch 1, Batch 120] loss: 0.489  accurac

KeyboardInterrupt: 