In [70]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import numpy as np
import os
import sys
import glob
from pathlib import Path
import nibabel as nib
from torchvision.transforms.functional import pad, center_crop
import matplotlib.pyplot as plt
from PIL import Image
import torchvision.transforms.functional as TF

# Put the parent of the current working directory on the import path so the
# project-local `func` package (presumably located there) can be imported.
project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
# Guard the append: re-running this cell in a live kernel would otherwise add
# the same entry to sys.path again on every execution.
if project_root not in sys.path:
    sys.path.append(project_root)
from func.data import get_loaders
from func.funcs import run_model, print_losses

# Data-loading hyperparameters (passed to get_loaders further down).
BATCH_SIZE, IMG_SIZE, NUM_WORKERS = 64, 64, 0

# Pick the compute backend in order of preference: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

In [74]:
# Build the train/test loaders from the dataset directory, failing fast with a
# readable message if that directory is missing.
# NOTE(review): the last recorded run of this cell ended in
# "IndexError: list index out of range" during the get_loaders call. A missing
# or empty data directory is a plausible cause, but get_loaders is not visible
# here — confirm against func/data.py.
DATA_DIR = "../archive/"
if not os.path.isdir(DATA_DIR):
    raise FileNotFoundError(
        f"Dataset directory {os.path.abspath(DATA_DIR)!r} not found "
        "(the path is resolved relative to the notebook's working directory)."
    )
train_loader, test_loader = get_loaders(
    DATA_DIR,
    batch_size=BATCH_SIZE,
    img_size=IMG_SIZE,
    num_workers=NUM_WORKERS,
)

IndexError: list index out of range

In [None]:
def double_conv(in_channels, out_channels):
    """
    Build the basic U-Net stage: (3x3 Conv -> BatchNorm -> ReLU) applied twice.

    Both convolutions use padding=1 with the default stride, so the spatial
    size (H x W) of the input is preserved; only the channel count changes,
    from `in_channels` to `out_channels`. The convolutions carry no bias term
    (each one is immediately followed by BatchNorm2d).

    Returns:
        nn.Sequential with six layers, in the order
        Conv2d, BatchNorm2d, ReLU, Conv2d, BatchNorm2d, ReLU.
    """
    layers = []
    # First pass maps in_channels -> out_channels; the second keeps out_channels.
    for conv_in in (in_channels, out_channels):
        layers.append(
            nn.Conv2d(conv_in, out_channels, kernel_size=3, padding=1, bias=False)
        )
        layers.append(nn.BatchNorm2d(out_channels))
        layers.append(nn.ReLU(inplace=True))
    return nn.Sequential(*layers)
class SimpleUNet(nn.Module):
    """
    Small three-level U-Net for dense (per-pixel) prediction.

    Encoder: three `double_conv` stages (16, 32, 64 channels), each followed by
    2x2 max pooling, then a 128-channel bottleneck. Decoder: three steps that
    bilinearly upsample, concatenate the matching encoder feature map along the
    channel axis, and apply `double_conv` (64, 32, 16 channels). A 1x1
    convolution followed by a sigmoid produces the output.

    Args:
        in_channels: Number of channels of the input image (default 1).
        num_classes: Number of output channels (default 1).

    Shapes:
        forward takes a tensor of shape (N, in_channels, H, W) and returns a
        tensor of shape (N, num_classes, H, W) with values in [0, 1]; the
        sigmoid is applied element-wise. H and W must each be at least 8 so
        the three poolings leave a non-empty bottleneck, but they do not have
        to be multiples of 8.
    """

    def __init__(self, in_channels=1, num_classes=1):
        super().__init__()

        # Encoder / bottleneck stages.
        self.conv_down1 = double_conv(in_channels, 16)
        self.conv_down2 = double_conv(16, 32)
        self.conv_down3 = double_conv(32, 64)
        self.conv_down4 = double_conv(64, 128)

        self.maxpool = nn.MaxPool2d(2)

        # Decoder stages: input channels = skip channels + upsampled channels.
        self.conv_up3 = double_conv(64 + 128, 64)
        self.conv_up2 = double_conv(32 + 64, 32)
        self.conv_up1 = double_conv(32 + 16, 16)

        # 1x1 projection to the requested number of output channels.
        self.last_conv = nn.Conv2d(16, num_classes, kernel_size=1)

    @staticmethod
    def _upsample_like(x, skip):
        """Bilinearly resize `x` to the spatial size of `skip`."""
        # Resizing to the skip's exact size (rather than a fixed 2x factor)
        # keeps torch.cat valid when a pooled dimension was odd and MaxPool2d
        # floored it. When the skip is exactly twice as large, this gives the
        # same result as a 2x bilinear upsample with align_corners=True.
        return F.interpolate(
            x, size=skip.shape[-2:], mode="bilinear", align_corners=True
        )

    def forward(self, x):
        """Run the encoder-decoder and return per-pixel sigmoid activations."""
        # Encoder: keep each pre-pooling feature map for the skip connections.
        conv1 = self.conv_down1(x)
        x = self.maxpool(conv1)
        conv2 = self.conv_down2(x)
        x = self.maxpool(conv2)
        conv3 = self.conv_down3(x)
        x = self.maxpool(conv3)

        # Bottleneck.
        x = self.conv_down4(x)

        # Decoder: upsample, concatenate the skip on the channel axis, convolve.
        x = self._upsample_like(x, conv3)
        x = torch.cat([x, conv3], dim=1)
        x = self.conv_up3(x)

        x = self._upsample_like(x, conv2)
        x = torch.cat([x, conv2], dim=1)
        x = self.conv_up2(x)

        x = self._upsample_like(x, conv1)
        x = torch.cat([x, conv1], dim=1)
        x = self.conv_up1(x)

        out = self.last_conv(x)
        out = torch.sigmoid(out)

        return out

In [None]:
# Instantiate the model and run one forward pass as a shape check.
# The pass runs in eval mode under no_grad so that the random input neither
# updates the BatchNorm running statistics nor builds an autograd graph.
unet = SimpleUNet().to(device)
unet.eval()
with torch.no_grad():
    # 512 is divisible by 8, which matches the three 2x poolings in SimpleUNet.
    output = unet(torch.randn(1, 1, 512, 512, device=device))
unet.train()  # restore the default training mode for the cells that follow
print(output.shape)

In [None]:
# Loss and optimizer used to train `unet`.
# BCELoss takes probabilities, which matches the sigmoid applied at the end of
# SimpleUNet.forward.
criterion = nn.BCELoss()

# Adam with its default settings apart from the learning rate; no weight_decay
# (L2 penalty) is passed.
lr = 1e-4
optimizer = optim.Adam(params=unet.parameters(), lr=lr)

In [None]:
# Train `unet` on `train_loader` for a fixed number of epochs using the
# `criterion` and `optimizer` defined above.
epochs = 300
log_every = 10  # print the current batch loss every `log_every` steps

unet.train()
for epoch in range(epochs):
    epoch_loss = 0.0
    num_batches = 0
    for i, (images, labels) in enumerate(train_loader):
        # NOTE(review): assumes labels are float masks with the same shape as
        # the model output, as nn.BCELoss requires — confirm against get_loaders.
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = unet(images)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        epoch_loss += batch_loss
        num_batches += 1
        if (i + 1) % log_every == 0:
            print(f'Epoch [{epoch+1}], Step [{i+1}], Loss: {batch_loss:.4f}')

    # An empty loader would otherwise spin through every epoch without doing
    # any work or printing anything.
    if num_batches == 0:
        raise RuntimeError("train_loader yielded no batches; nothing to train on.")

    # Per-epoch summary: printed even when the loader has fewer than
    # `log_every` batches, in which case the step-level print never fires.
    print(f'Epoch [{epoch+1}/{epochs}] mean loss: {epoch_loss / num_batches:.4f}')