In [4]:
import os
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import models, transforms
import pandas as pd
import timm
import torch.nn.functional as F

# If using your custom data loader, import it (or define your own here)
from data import build_split_dataloaders  # ...existing code...
from model import SwinTransformerClassificationModel  # New import for SwinTransformer

# Hyperparameters
batch_size = 6
learning_rate = 0.001
num_epochs = 10
experiment = "convnextv2"
log_dir = f"runs/{experiment}"

# Data paths
root_dir = os.path.join("K:", "rsna-breast-cancer-detection")
csv_path = os.path.join(root_dir, "train.csv")
root_dir = os.path.join(root_dir, "train_images_cropped")

# Data transformations
transform = transforms.Compose([
    transforms.ToTensor(),
    # RESIZE TO 224x224
    transforms.Resize((384, 384)),
])

# Build dataloaders (Assumes build_split_dataloaders is defined in data.py)
train_loader, val_loader, test_loader = build_split_dataloaders(
    csv_path, root_dir, batch_size=batch_size, transform=transform, train=True, val_ratio=0.2, test_ratio=0.1, paired=True
)

print(f"Train batches: {len(train_loader)}, Validation batches: {len(val_loader)}")

Train batches: 976, Validation batches: 279


In [5]:
from data import MedicalImageDataset

# Create a new dataset to check the class distribution
dataset = MedicalImageDataset(csv_path, root_dir, transform=transform, paired=True)
dataframe = dataset.metadata
cc_birads = dataframe["cc_birads"].values
mlo_birads = dataframe["mlo_birads"].values

birads_counts = cc_birads.tolist() + mlo_birads.tolist()
birads_counts = pd.Series(birads_counts).value_counts()

print(birads_counts)

# Create weights for loss function based on class distribution
birads_weights = 1 / birads_counts
birads_weights = birads_weights / birads_weights.sum()
birads_weights = birads_weights.sort_index()

birads_weights = torch.tensor(birads_weights.values).float()
birads_weights = birads_weights.to("cuda")

print(birads_weights)

0.0    8756
1.0    6266
2.0    1692
Name: count, dtype: int64
tensor([0.1321, 0.1845, 0.6834], device='cuda:0')


In [6]:
from model import SwinTransformerClassificationModel, SwinMammoClassifier, ConvNeXtClassificationModel
# Initialize model, criterion, optimizer, and TensorBoard writer

num_classes = 3

# model = SwinTransformerClassificationModel(num_classes=num_classes)
# model = SwinMammoClassifier(num_classes=num_classes)
model = ConvNeXtClassificationModel(num_classes=num_classes)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

writer = SummaryWriter(log_dir=log_dir)

print(f"Using device: {device}")

Using device: cuda


In [None]:
import numpy as np

# Training loop with extra logging
global_step = 0
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    correct_train_batch = 0
    correct_total = 0
    total_train_batch = 0
    total_train = 0
    total_loss = 0.0
    epoch_start = time.time()
    
    for i, data in enumerate(train_loader):
        inputs, labels = data
        
        mlo, cc = inputs
        mlo_labels, cc_labels = labels
        mlo = mlo.to(device)
        cc = cc.to(device)
        mlo_labels = mlo_labels.to(device).long()
        cc_labels = cc_labels.to(device).long()
        
        optimizer.zero_grad()
        
        model_input = torch.cat((mlo, cc), dim=0)
        outputs = model(model_input)
        
        mlo_outputs, cc_outputs = torch.split(outputs, mlo.size(0), dim=0)
        
        # mlo_outputs, cc_outputs = model(mlo, cc)
        
        # Ensure the batch sizes match
        mlo_outputs = mlo_outputs.view(-1, num_classes)
        cc_outputs = cc_outputs.view(-1, num_classes)
        mlo_labels = mlo_labels.view(-1)
        cc_labels = cc_labels.view(-1)
        
        loss = criterion(mlo_outputs, mlo_labels) + criterion(cc_outputs, cc_labels)
        
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        # Compute predictions
        _, mlo_preds = torch.max(mlo_outputs, 1)
        _, cc_preds = torch.max(cc_outputs, 1)
        
        correct_train_batch += torch.sum(mlo_preds == mlo_labels).item() + torch.sum(cc_preds == cc_labels).item()
        
        
        total_train_batch += mlo_labels.size(0) + cc_labels.size(0)
        global_step += 1
        
        # _, preds = torch.max(outputs, 1)
        # correct_train += torch.sum(preds == labels).item()
        # total_train += labels.size(0)
        # global_step += 1
        
        if i % 10 == 9:
            avg_loss = running_loss / 10 / 2
            
            total_train += total_train_batch
            correct_total += correct_train_batch
            total_loss += running_loss
            
            avg_total_loss = total_loss / (i+1)
            
            train_acc_batch = correct_train_batch / total_train_batch
            print(f"[Epoch {epoch+1}, Batch {i+1}] loss: {avg_loss:.3f}  accuracy: {train_acc_batch:.3f} total accuracy: {correct_total / total_train:.3f} avg loss: {avg_total_loss:.3f}")
            writer.add_scalar('training loss', avg_loss, global_step)
            writer.add_scalar('training accuracy', train_acc_batch, global_step)
            
            writer.add_scalar('total training accuracy', correct_total / total_train, global_step)
            writer.add_scalar('total training loss', avg_total_loss, global_step)
            
            running_loss = 0.0
            correct_train_batch = 0
            total_train_batch = 0
    
    epoch_time = time.time() - epoch_start
    print(f"Epoch {epoch+1} completed in {epoch_time:.2f} seconds")
    
    # Validation loop with same logic as training
    model.eval()
    running_loss = 0.0
    correct_batch = 0
    total_batch = 0
    total_loss = 0.0
    total_correct = 0
    total_samples = 0
    with torch.no_grad():
        for i, data in enumerate(val_loader):
            inputs, labels = data
            mlo, cc = inputs
            mlo_labels, cc_labels = labels
            mlo = mlo.to(device)
            cc = cc.to(device)
            mlo_labels = mlo_labels.to(device).long()
            cc_labels = cc_labels.to(device).long()
            model_input = torch.cat((mlo, cc), dim=0)
            outputs = model(model_input)
            mlo_outputs, cc_outputs = torch.split(outputs, mlo.size(0), dim=0)
            mlo_outputs = mlo_outputs.view(-1, num_classes)
            cc_outputs = cc_outputs.view(-1, num_classes)
            mlo_labels = mlo_labels.view(-1)
            cc_labels = cc_labels.view(-1)
            loss = criterion(mlo_outputs, mlo_labels) + criterion(cc_outputs, cc_labels)
            running_loss += loss.item()
            _, mlo_preds = torch.max(mlo_outputs, 1)
            _, cc_preds = torch.max(cc_outputs, 1)
            correct = torch.sum(mlo_preds == mlo_labels).item() + torch.sum(cc_preds == cc_labels).item()
            batch_total = mlo_labels.size(0) + cc_labels.size(0)
            correct_batch += correct
            total_batch += batch_total
            total_loss += loss.item()
            total_correct += correct
            total_samples += batch_total
            if i % 10 == 9:
                avg_loss = running_loss / 10 / 2
                batch_acc = correct_batch / total_batch
                writer.add_scalar('validation loss', avg_loss, i)
                writer.add_scalar('validation accuracy', batch_acc, i)
                running_loss = 0.0
                correct_batch = 0
                total_batch = 0
    val_loss_avg = total_loss / len(val_loader) / 2
    val_acc = total_correct / total_samples
    print(f"Validation loss after epoch {epoch+1}: {val_loss_avg:.3f}  accuracy: {val_acc:.3f}")
    writer.add_scalar('validation loss', val_loss_avg, epoch)
    writer.add_scalar('validation accuracy', val_acc, epoch)

# Save the model
torch.save(model.state_dict(), f"model_{experiment}.pth")
print("Model saved to model.pth")
writer.close()

[Epoch 1, Batch 10] loss: 1.410  accuracy: 0.433 total accuracy: 0.433 avg loss: 2.819
[Epoch 1, Batch 20] loss: 1.228  accuracy: 0.467 total accuracy: 0.450 avg loss: 2.638
[Epoch 1, Batch 30] loss: 1.047  accuracy: 0.400 total accuracy: 0.433 avg loss: 2.456
[Epoch 1, Batch 40] loss: 0.995  accuracy: 0.517 total accuracy: 0.454 avg loss: 2.340
[Epoch 1, Batch 50] loss: 0.889  accuracy: 0.583 total accuracy: 0.480 avg loss: 2.227
[Epoch 1, Batch 60] loss: 0.950  accuracy: 0.583 total accuracy: 0.497 avg loss: 2.173
[Epoch 1, Batch 70] loss: 0.952  accuracy: 0.583 total accuracy: 0.510 avg loss: 2.135
[Epoch 1, Batch 80] loss: 1.002  accuracy: 0.400 total accuracy: 0.496 avg loss: 2.118
[Epoch 1, Batch 90] loss: 0.969  accuracy: 0.583 total accuracy: 0.506 avg loss: 2.098
[Epoch 1, Batch 100] loss: 1.028  accuracy: 0.350 total accuracy: 0.490 avg loss: 2.094
[Epoch 1, Batch 110] loss: 0.870  accuracy: 0.617 total accuracy: 0.502 avg loss: 2.062
[Epoch 1, Batch 120] loss: 0.970  accurac