In [1]:
# finetune.py
import os
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '0'
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
from pathlib import Path
import wandb
import random
from monai.networks.nets import SwinUNETR
from monai.transforms import (
    Compose,
    RandFlipd,
    RandRotate90d,
    OneOf,
)
import datetime
import platform
from dataset import MicrotubuleDataset

# ================== Configuration Variables ==================
EXPERIMENT_NAME = 'swinunetr_finetuning'
# Data paths
IMAGE_DIR = 'training/subvols/image'
LABEL_DIR = 'training/labeled_sdt'
CHECKPOINT_DIR = 'checkpoints'

# Training parameters
BATCH_SIZE = 1
NUM_EPOCHS = 200
LEARNING_RATE = 5e-4
WEIGHT_DECAY = 1e-5
NUM_WORKERS = 0
RANDOM_SEED = 42
PATIENCE = 20

# Model parameters
SPATIAL_DIMS = 3
IN_CHANNELS = 1
OUT_CHANNELS = 1
FEATURE_SIZE = 96  # Reduced from default 48
DROP_RATE = 0.3
ATTN_DROP_RATE = 0.3
DROPOUT_PATH_RATE = 0.3

# Device configuration
DEVICE = torch.device('cuda')
print(f"Using device: {DEVICE}")

# Set random seeds
torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
random.seed(RANDOM_SEED)

hyperparameters = {
    'experiment_name': EXPERIMENT_NAME,
    'batch_size': BATCH_SIZE,
    'num_epochs': NUM_EPOCHS,
    'learning_rate': LEARNING_RATE,
    'weight_decay': WEIGHT_DECAY,
    'feature_size': FEATURE_SIZE,
    'drop_rate': DROP_RATE,
    'attn_drop_rate': ATTN_DROP_RATE,
    'dropout_path_rate': DROPOUT_PATH_RATE,
    'device': str(DEVICE),
    'torch_version': torch.__version__,
    'python_version': platform.python_version(),
    'training_start_time': datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
}

# ================== Model Definition ==================
class TanhSwinUNETR(nn.Module):
    def __init__(self, img_size, in_channels, out_channels, feature_size=48, 
                 drop_rate=0.0, attn_drop_rate=0.0, dropout_path_rate=0.0):
        super().__init__()
        self.model = SwinUNETR(
            img_size=img_size,
            in_channels=in_channels,
            out_channels=out_channels,
            feature_size=feature_size,
            drop_rate=drop_rate,
            attn_drop_rate=attn_drop_rate,
            dropout_path_rate=dropout_path_rate
        )
    
    def forward(self, x):
        x = self.model(x)
        return torch.tanh(x)

def initialize_model(pretrained=True):
    """Initialize SwinUNETR with pretrained weights"""
    model = TanhSwinUNETR(
        img_size=(96, 96, 96),
        in_channels=IN_CHANNELS,
        out_channels=OUT_CHANNELS,
        feature_size=FEATURE_SIZE,
        drop_rate=DROP_RATE,
        attn_drop_rate=ATTN_DROP_RATE,
        dropout_path_rate=DROPOUT_PATH_RATE
    )
    
    if pretrained:
        try:
            from monai.apps import download_and_extract
            resource = "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/swin_unetr.base_5000ep_f48_lr2e-4_pretrained.pt"
            weightsfile = download_and_extract(resource)
            
            pretrained_weights = torch.load(weightsfile, map_location='cuda')
            model_dict = model.model.state_dict()
            pretrained_dict = {k: v for k, v in pretrained_weights.items() 
                             if k in model_dict and v.shape == model_dict[k].shape}
            
            model_dict.update(pretrained_dict)
            model.model.load_state_dict(model_dict, strict=False)
            print("Loaded pretrained weights")
            
            # Freeze encoder layers
            for name, param in model.model.named_parameters():
                if 'encoder' in name:
                    param.requires_grad = False
                    print(f"Froze {name}")
                    
        except Exception as e:
            print(f"Could not load pretrained weights: {e}")
            print("Initializing with random weights")

    # Move model to CUDA and wrap with DataParallel if multiple GPUs available
    if torch.cuda.is_available():
        if torch.cuda.device_count() > 1:
            print(f"Using {torch.cuda.device_count()} GPUs")
            model = torch.nn.DataParallel(model)
        else:
            print('using 1 gpu')
    else:
        print('cuda not available')
    model = model.to(DEVICE)
    return model

# ================== Training Functions ==================
def get_transforms():
    return Compose([
        OneOf([
            Compose([]),  # No flips
            Compose([RandFlipd(keys=['image', 'label'], prob=1.0, spatial_axis=0)]),
            Compose([RandFlipd(keys=['image', 'label'], prob=1.0, spatial_axis=1)]),
            Compose([RandFlipd(keys=['image', 'label'], prob=1.0, spatial_axis=2)]),
        ], weights=[1/4] * 4),
        
        OneOf([
            Compose([]),  # No rotation
            Compose([RandRotate90d(keys=['image', 'label'], prob=1.0, max_k=1, spatial_axes=(0, 1))]),
            Compose([RandRotate90d(keys=['image', 'label'], prob=1.0, max_k=1, spatial_axes=(1, 2))]),
            Compose([RandRotate90d(keys=['image', 'label'], prob=1.0, max_k=1, spatial_axes=(0, 2))]),
        ], weights=[1/4] * 4)
    ])

def setup_training(model):
    encoder_params = []
    decoder_params = []
    
    for name, param in model.named_parameters():
        if param.requires_grad:  # Only include unfrozen parameters
            if 'encoder' in name:
                encoder_params.append(param)
            else:
                decoder_params.append(param)
    
    optimizer = optim.AdamW([
        {'params': encoder_params, 'lr': LEARNING_RATE * 0.1},
        {'params': decoder_params, 'lr': LEARNING_RATE}
    ], weight_decay=WEIGHT_DECAY)
    
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=0.5,
        patience=5,
        verbose=True
    )
    
    return optimizer, scheduler

def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler):
    wandb.init(project="microtubule-segmentation", name=EXPERIMENT_NAME, config=hyperparameters)
    best_val_loss = float('inf')
    patience_counter = 0
    
    for epoch in range(NUM_EPOCHS):
        # Training phase
        model.train()
        train_loss = 0
        
        for images, labels in train_loader:
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)
            
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            
            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            
            optimizer.step()
            train_loss += loss.item()
        
        avg_train_loss = train_loss / len(train_loader)
        
        # Validation phase
        model.eval()
        val_loss = 0
        
        with torch.no_grad():
            for images, labels in val_loader:
                images = images.to(DEVICE)
                labels = labels.to(DEVICE)
                
                outputs = model(images)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
        
        avg_val_loss = val_loss / len(val_loader)
        
        # Learning rate scheduling
        scheduler.step(avg_val_loss)
        
        # Logging
        wandb.log({
            'train_loss': avg_train_loss,
            'val_loss': avg_val_loss,
            'epoch': epoch + 1,
            'learning_rate': optimizer.param_groups[0]['lr']
        })
        
        print(f'Epoch [{epoch+1}/{NUM_EPOCHS}] Train Loss: {avg_train_loss:.5f} Val Loss: {avg_val_loss:.5f}')
        
        # Save best model
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': best_val_loss,
            }, f"{CHECKPOINT_DIR}/best_model.pt")
            patience_counter = 0
        else:
            patience_counter += 1
            
        if patience_counter >= PATIENCE:
            print(f"Early stopping triggered after epoch {epoch+1}")
            break
    
    wandb.finish()

# ================== Main Function ==================
def main():
    # Create checkpoint directory
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)
    
    # Initialize model
    model = initialize_model(pretrained=True)
    model = model.to(DEVICE)
    
    # Print parameter summary
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"\nTrainable parameters: {trainable_params:,} / {total_params:,}")
    
    # Create datasets
    transforms = get_transforms()
    
    dataset = MicrotubuleDataset(
        image_dir=IMAGE_DIR,
        label_dir=LABEL_DIR
    )
    
    # Split indices
    dataset_length = len(dataset)
    val_indices = list(range(0, dataset_length, 5))
    train_indices = [i for i in range(dataset_length) if i not in val_indices]
    
    # Create train/val datasets
    train_dataset = MicrotubuleDataset(
        image_dir=IMAGE_DIR,
        label_dir=LABEL_DIR,
        indices=train_indices,
        transforms=transforms
    )
    
    val_dataset = MicrotubuleDataset(
        image_dir=IMAGE_DIR,
        label_dir=LABEL_DIR,
        indices=val_indices,
        transforms=transforms
    )
    
    print(f"Training samples: {len(train_dataset)}")
    print(f"Validation samples: {len(val_dataset)}")
    
    # Create dataloaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=True
    )
    
    val_loader = DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
        pin_memory=True
    )
    
    # Setup training
    optimizer, scheduler = setup_training(model)
    
    # Loss function
    criterion = nn.L1Loss()
    
    # Train model
    train_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler
    )

if __name__ == "__main__":
    main()

2024-11-26 20:32:57.825110: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-11-26 20:32:57.837979: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-11-26 20:32:57.852825: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-11-26 20:32:57.857281: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-11-26 20:32:57.868659: I tensorflow/core/platform/cpu_feature_guar

Using device: cuda


monai.networks.nets.swin_unetr SwinUNETR.__init__:img_size: Argument `img_size` has been deprecated since version 1.3. It will be removed in version 1.5. The img_size argument is not required anymore and checks on the input size are run during forward().
swin_unetr.base_5000ep_f48_lr2e-4_pretrained.pt: 244MB [00:07, 36.4MB/s]                              

2024-11-26 20:33:12,492 - INFO - Downloaded: /tmp/tmpszvpn9eu/swin_unetr.base_5000ep_f48_lr2e-4_pretrained.pt
2024-11-26 20:33:12,493 - INFO - Expected md5 is None, skip md5 check for file /tmp/tmpszvpn9eu/swin_unetr.base_5000ep_f48_lr2e-4_pretrained.pt.
2024-11-26 20:33:12,493 - INFO - Writing into directory: ..





Could not load pretrained weights: Unsupported file type, available options are: ["zip", "tar.gz", "tar"]. name=/tmp/tmpszvpn9eu/swin_unetr.base_5000ep_f48_lr2e-4_pretrained.pt type=.
Initializing with random weights
using 1 gpu

Trainable parameters: 248,089,315 / 248,089,315
Found 48 valid image-label pairs
Found 48 valid image-label pairs
Found 48 valid image-label pairs
Training samples: 38
Validation samples: 10


The verbose parameter is deprecated. Please use get_last_lr() to access the learning rate.
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33maatmik54[0m ([33maatmikmallya[0m). Use [1m`wandb login --relogin`[0m to force relogin


Epoch [1/200] Train Loss: 0.13488 Val Loss: 0.05259
Epoch [2/200] Train Loss: 0.05343 Val Loss: 0.04731
Epoch [3/200] Train Loss: 0.04783 Val Loss: 0.04612


KeyboardInterrupt: 