In [2]:
!pip install --force-reinstall torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

Looking in indexes: https://download.pytorch.org/whl/cu121
Collecting torch
  Downloading https://download.pytorch.org/whl/cu121/torch-2.5.1%2Bcu121-cp310-cp310-win_amd64.whl (2449.4 MB)
     ---------------------------------------- 0.0/2.4 GB ? eta -:--:--
     ---------------------------------------- 0.0/2.4 GB 10.5 MB/s eta 0:03:54
     ---------------------------------------- 0.0/2.4 GB 10.7 MB/s eta 0:03:50
     ---------------------------------------- 0.0/2.4 GB 5.2 MB/s eta 0:07:55
     ---------------------------------------- 0.0/2.4 GB 5.2 MB/s eta 0:07:55
     ---------------------------------------- 0.0/2.4 GB 5.2 MB/s eta 0:07:55
     ---------------------------------------- 0.0/2.4 GB 5.2 MB/s eta 0:07:55
     ---------------------------------------- 0.0/2.4 GB 5.2 MB/s eta 0:07:55
     ---------------------------------------- 0.0/2.4 GB 5.2 MB/s eta 0:07:55
     ---------------------------------------- 0.0/2.4 GB 5.2 MB/s eta 0:07:55
     ---------------------------------

ERROR: Could not install packages due to an OSError: [WinError 32] The process cannot access the file because it is being used by another process: 'C:\\Users\\tasni\\AppData\\Local\\Temp\\pip-unpack-xt9j9vtx\\torch-2.5.1+cu121-cp310-cp310-win_amd64.whl'
Consider using the `--user` option or check the permissions.



In [1]:
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms.functional import to_tensor
import matplotlib.pyplot as plt
from tqdm import tqdm
from networks.vision_transformer import SwinUnet



  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class Config:
    """Static hyper-parameters and filesystem paths shared by the notebook cells.

    The class is used as a plain namespace: every setting is a class attribute
    and the instance bound to ``args`` below only provides attribute access.
    """

    root_path = "./Dataset"  # <-- This must exist!
    img_size = 224           # side length (pixels) images are resized to
    num_classes = 3          # number of segmentation classes / model output channels
    base_lr = 0.01           # initial SGD learning rate
    batch_size = 4
    max_epochs = 100
    n_gpu = 1                # values > 1 make the model cell wrap the network in DataParallel
    num_workers = 4          # DataLoader worker processes
    eval_interval = 5        # not referenced by the cells visible here -- TODO confirm it is still needed
    seed = 42
    snapshot_path = "./swin_output"  # directory that receives model checkpoints
    pretrained_ckpt = "./pretrained_ckpt/swin_tiny_patch4_window7_224.pth"

args = Config()


# Create output directory if it doesn't exist
os.makedirs(args.snapshot_path, exist_ok=True)

# Set random seed for reproducibility.
# Python's `random` and NumPy keep RNG state separate from torch, so each one
# is seeded explicitly; torch goes last so the cell's displayed value is unchanged.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)


<torch._C.Generator at 0x1cdd6a837f0>

In [3]:
# Create HepaticDataset
from PIL import Image   

class HepaticDataset(Dataset):
    """Paired grayscale-image / segmentation-mask dataset read from two folders.

    Every regular file found in ``image_dir`` is one sample; its mask is the
    file with the same name inside ``mask_dir``.

    Args:
        image_dir: Directory containing the input images.
        mask_dir: Directory containing the masks (same file names as the images).
        transform: Optional callable applied to the PIL image only. When it is
            omitted, or does not return a tensor, the image is converted to a
            ``(1, H, W)`` float tensor scaled to ``[0, 1]``.
        mask_transform: Optional callable applied to the PIL mask. When omitted,
            the mask is resized with nearest-neighbour interpolation to the
            spatial size of the transformed image and its raw integer values
            are kept as labels.

    Notes:
        The image transform is deliberately *not* applied to the mask: an
        interpolating resize blends neighbouring class ids, and ``ToTensor``
        rescales 8-bit values into ``[0, 1]`` so that a later ``.long()`` cast
        truncates almost every label to 0.
        NOTE(review): this assumes the mask files store class indices
        (0 .. num_classes - 1) directly as pixel values -- confirm against the
        data. Random geometric augmentations passed via ``transform`` are not
        mirrored onto the mask by the default path.
    """

    def __init__(self, image_dir, mask_dir, transform=None, mask_transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.mask_transform = mask_transform
        # Keep regular files only, so stray sub-directories (for example
        # ".ipynb_checkpoints") are never handed to the image loader.
        self.image_files = sorted(
            name for name in os.listdir(image_dir)
            if os.path.isfile(os.path.join(image_dir, name))
        )

    def __len__(self):
        """Return the number of image files discovered in ``image_dir``."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``{'image': float tensor, 'label': long tensor}``."""
        file_name = self.image_files[idx]

        # Context managers release the file handles as soon as pixels are read.
        with Image.open(os.path.join(self.image_dir, file_name)) as img_file:
            image = img_file.convert("L")
        with Image.open(os.path.join(self.mask_dir, file_name)) as mask_file:
            mask = mask_file.copy()

        if self.transform:
            image = self.transform(image)
        if not torch.is_tensor(image):
            # No tensor-producing transform was supplied: build a
            # (1, H, W) float tensor in [0, 1] from the 8-bit grayscale image.
            image = torch.from_numpy(np.asarray(image, dtype=np.float32) / 255.0)
            if image.ndim == 2:
                image = image.unsqueeze(0)

        if self.mask_transform is not None:
            mask = self.mask_transform(mask)
        else:
            # Match the image's spatial size without inventing new label values.
            height, width = image.shape[-2:]
            if mask.size != (width, height):  # PIL reports size as (width, height)
                mask = mask.resize((width, height), Image.NEAREST)
        if not torch.is_tensor(mask):
            mask = torch.from_numpy(np.array(mask).astype(np.int64))

        return {
            'image': image.float(),
            'label': mask.long().squeeze()
        }


In [4]:
# Define transforms and data loader
# Resize every sample to the square network input size and convert it to a tensor.
train_transform = transforms.Compose([
    transforms.Resize((args.img_size, args.img_size)),
    transforms.ToTensor(),
])

train_dataset = HepaticDataset(
    image_dir=os.path.join(args.root_path, '2D_Sliced_Images'),
    mask_dir=os.path.join(args.root_path, '2D_Sliced_Masks'),
    transform=train_transform
)

# DataLoader workers are separate processes that must unpickle the dataset.
# On Windows they are spawned (not forked), and a Dataset class defined in a
# notebook cell cannot be re-imported by the spawned worker, so loading fails
# or hangs. Fall back to in-process loading there.
loader_workers = 0 if os.name == "nt" else args.num_workers

train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=loader_workers)


In [5]:
# Define the Dice loss
class DiceLoss(nn.Module):
    """Multi-class soft Dice loss, averaged over classes.

    Args:
        num_classes: Number of classes; must equal the channel dimension of
            the tensor passed as ``input`` to :meth:`forward`.
    """

    def __init__(self, num_classes):
        super(DiceLoss, self).__init__()
        self.num_classes = num_classes

    def forward(self, input, target, smooth=1e-5, softmax=True):
        """Return ``1 - mean(per-class Dice)``.

        Args:
            input: ``(N, C, H, W)`` tensor of logits (or probabilities when
                ``softmax=False``).
            target: ``(N, H, W)`` integer tensor of class indices in
                ``[0, num_classes)``.
            smooth: Constant added to numerator and denominator so classes that
                are absent from the batch do not divide by zero.
            softmax: Apply a softmax over the channel dimension first.

        Returns:
            Scalar tensor holding the loss.
        """
        if softmax:
            input = torch.softmax(input, dim=1)

        # Build the one-hot target directly on the target's device. Indexing a
        # CPU identity matrix with a CUDA label tensor (the previous approach)
        # raises a device-mismatch error as soon as training runs on the GPU.
        target_onehot = nn.functional.one_hot(target.long(), self.num_classes)
        # (N, H, W, C) -> (N, C, H, W), matching input's device and dtype.
        target_onehot = target_onehot.permute(0, 3, 1, 2).to(device=input.device, dtype=input.dtype)

        # Reduce over batch and spatial axes, leaving one value per class.
        dims = (0, 2, 3)
        intersection = torch.sum(input * target_onehot, dims)
        cardinality = torch.sum(input + target_onehot, dims)

        dice = (2. * intersection + smooth) / (cardinality + smooth)
        return 1. - dice.mean()


In [6]:
# loads a YAML configuration file and converts it into a CfgNode
from yacs.config import CfgNode as CN
import yaml

# Load YAML into a dictionary.
# The encoding is given explicitly because open() otherwise uses the
# platform's locale encoding, which is not UTF-8 on every system.
with open('configs/swin_tiny_patch4_window7_224_lite.yaml', 'r', encoding='utf-8') as f:
    yaml_cfg = yaml.safe_load(f)

# Convert dictionary to CfgNode (nested access support)
config = CN(yaml_cfg)


In [7]:
# Report the PyTorch build and the accelerator it can see.
cuda_ok = torch.cuda.is_available()
if cuda_ok:
    gpu_name = torch.cuda.get_device_name(0)
else:
    gpu_name = "No GPU found"

for label, value in (
    ("Torch version:", torch.__version__),
    ("CUDA available:", cuda_ok),
    ("GPU:", gpu_name),
    ("PyTorch CUDA version:", torch.version.cuda),
):
    print(label, value)



Torch version: 2.5.1+cu121
CUDA available: True
GPU: NVIDIA GeForce GTX 1650
PyTorch CUDA version: 12.1


In [7]:
# Load SwinUnet model
from yacs.config import CfgNode as CN
import torch
import os
from networks.vision_transformer import SwinUnet

# ---- Load YAML config ----
from yaml import safe_load

# Read the model YAML; the encoding is explicit so the result does not depend
# on the platform's locale default.
with open('configs/swin_tiny_patch4_window7_224_lite.yaml', 'r', encoding='utf-8') as f:
    yaml_cfg = safe_load(f)

# Wrap the plain dict in a CfgNode for attribute-style access.
config = CN(yaml_cfg)

# ---- Set extra training args ----
class Args:
    """Training settings for this cell; rebinds ``args`` defined by ``Config``.

    Mirrors the attributes of ``Config`` so code that reads ``args`` keeps
    working after this cell runs; only ``img_size`` differs, being taken from
    the loaded YAML config instead of a literal.
    NOTE(review): the data transforms were built from ``Config.img_size`` --
    confirm that ``config.DATA.IMG_SIZE`` holds the same value.
    """

    root_path = "./Dataset"
    img_size = config.DATA.IMG_SIZE
    num_classes = 3
    base_lr = 0.01
    batch_size = 4
    max_epochs = 100
    n_gpu = 1
    num_workers = 4
    eval_interval = 5
    seed = 42
    snapshot_path = "./swin_output"
    # Kept for parity with Config so the attribute does not vanish when `args`
    # is rebound; the model cell itself calls load_from(config).
    pretrained_ckpt = "./pretrained_ckpt/swin_tiny_patch4_window7_224.pth"

args = Args()

# ---- Ensure output folder exists ----
os.makedirs(args.snapshot_path, exist_ok=True)

# ---- Seed every RNG the notebook imports, not just torch ----
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)

# ---- Step 1: Create the model ----
# ---- Step 1: Create the model ----
model = SwinUnet(
    config=config,
    img_size=args.img_size,
    num_classes=args.num_classes,
    zero_head=True
)

# ---- Step 2: Load pretrained weights from config path ----
model.load_from(config)

# ---- Step 3: Wrap in DataParallel if multiple GPUs ----
if args.n_gpu > 1:
    model = torch.nn.DataParallel(model)

# ---- Step 4: Move to GPU (fall back to CPU when CUDA is unavailable) ----
# An unconditional .cuda() raises on machines without a usable CUDA device;
# choosing the device first keeps this cell runnable everywhere.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)


SwinTransformerSys expand initial----depths:[2, 2, 2, 2];depths_decoder:[1, 2, 2, 2];drop_path_rate:0.2;num_classes:3


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


---final upsample expand_first---
pretrained_path:./pretrained_ckpt/swin_tiny_patch4_window7_224.pth


  pretrained_dict = torch.load(pretrained_path, map_location=device)


---start load pretrained modle of swin encoder---


In [8]:
# Optimizer and loss functions
# Criteria combined by the training loop: cross-entropy plus Dice.
ce_loss, dice_loss = nn.CrossEntropyLoss(), DiceLoss(args.num_classes)

# Plain SGD with momentum and weight decay over all model parameters.
sgd_options = dict(lr=args.base_lr, momentum=0.9, weight_decay=0.0001)
optimizer = optim.SGD(model.parameters(), **sgd_options)


In [None]:
# --- Training Setup ---
steps_per_epoch = len(train_loader)
if steps_per_epoch == 0:
    # Without this guard the run would end in an opaque ZeroDivisionError
    # when the epoch averages are computed.
    raise ValueError("train_loader is empty; check that the dataset directory contains images.")

max_iterations = args.max_epochs * steps_per_epoch
best_loss = float("inf")
iter_num = 0

# Send batches to whatever device the model lives on instead of assuming CUDA.
train_device = next(model.parameters()).device

best_ckpt_path = os.path.join(args.snapshot_path, 'best_model.pth')
last_ckpt_path = os.path.join(args.snapshot_path, 'last_model.pth')

# --- Early Stopping Parameters ---
# NOTE: the monitored quantity is the mean *training* loss of the epoch; no
# validation split is evaluated in this loop.
patience = 10  # you can change this to 5, 20, etc.
trigger_times = 0

for epoch in range(args.max_epochs):
    model.train()
    epoch_ce = 0
    epoch_dice = 0

    for batch in tqdm(train_loader, total=steps_per_epoch, desc=f"Epoch {epoch}"):
        images = batch['image'].to(train_device)
        labels = batch['label'].to(train_device)

        outputs = model(images)
        loss_ce = ce_loss(outputs, labels)
        loss_dice = dice_loss(outputs, labels, softmax=True)
        loss = 0.4 * loss_ce + 0.6 * loss_dice

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Polynomial decay of the learning rate (power 0.9) from base_lr
        # towards 0 over max_iterations steps; applies to the next step.
        iter_num += 1
        lr_ = args.base_lr * (1.0 - iter_num / max_iterations) ** 0.9
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr_

        epoch_ce += loss_ce.item()
        epoch_dice += loss_dice.item()

    epoch_ce /= steps_per_epoch
    epoch_dice /= steps_per_epoch
    total_loss = 0.4 * epoch_ce + 0.6 * epoch_dice
    print(f"Epoch {epoch} - Loss: {total_loss:.4f} | CE: {epoch_ce:.4f} | Dice: {epoch_dice:.4f}")

    # Always refresh the "last" checkpoint so it really is the most recent
    # state; previously it was written only on non-improving epochs.
    torch.save(model.state_dict(), last_ckpt_path)

    # --- Early Stopping Logic ---
    if total_loss < best_loss:
        torch.save(model.state_dict(), best_ckpt_path)
        best_loss = total_loss
        trigger_times = 0  # reset counter on improvement
    else:
        trigger_times += 1
        print(f"EarlyStopping counter: {trigger_times} out of {patience}")

        if trigger_times >= patience:
            print("Early stopping triggered. Stopping training.")
            break
