In [None]:
!pip install --force-reinstall torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

In [1]:
!pip install scikit-learn


Collecting scikit-learn
  Downloading scikit_learn-1.7.0-cp310-cp310-win_amd64.whl.metadata (14 kB)
Collecting joblib>=1.2.0 (from scikit-learn)
  Downloading joblib-1.5.1-py3-none-any.whl.metadata (5.6 kB)
Collecting threadpoolctl>=3.1.0 (from scikit-learn)
  Downloading threadpoolctl-3.6.0-py3-none-any.whl.metadata (13 kB)
Downloading scikit_learn-1.7.0-cp310-cp310-win_amd64.whl (10.7 MB)
   ---------------------------------------- 0.0/10.7 MB ? eta -:--:--
   - -------------------------------------- 0.5/10.7 MB 5.6 MB/s eta 0:00:02
   ---------- ----------------------------- 2.9/10.7 MB 9.3 MB/s eta 0:00:01
   ----------------- ---------------------- 4.7/10.7 MB 9.2 MB/s eta 0:00:01
   ------------------------ --------------- 6.6/10.7 MB 9.1 MB/s eta 0:00:01
   --------------------------------- ------ 8.9/10.7 MB 9.7 MB/s eta 0:00:01
   ---------------------------------------- 10.7/10.7 MB 9.8 MB/s eta 0:00:00
Downloading joblib-1.5.1-py3-none-any.whl (307 kB)
Downloading threadpool

In [2]:
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms.functional import to_tensor
import matplotlib.pyplot as plt
from tqdm import tqdm
from networks.vision_transformer import SwinUnet



  from .autonotebook import tqdm as notebook_tqdm


In [3]:
class Config:
    """Static settings for the first part of the notebook (paths, data, optimisation).

    NOTE(review): a later cell rebinds ``args`` to a different class whose
    ``img_size`` comes from the YAML config; confirm that value agrees with the
    ``img_size`` used here to build the data transforms.
    """

    # --- Locations ---
    root_path = "./Dataset"  # dataset root; the folder has to exist before the data cells run
    snapshot_path = "./swin_output"
    pretrained_ckpt = "./pretrained_ckpt/swin_tiny_patch4_window7_224.pth"

    # --- Data ---
    img_size = 64
    num_classes = 3
    batch_size = 4
    num_workers = 4

    # --- Optimisation / schedule ---
    base_lr = 0.01
    max_epochs = 5
    eval_interval = 5

    # --- Runtime ---
    n_gpu = 1
    seed = 42

args = Config()


# Create the output directory up front; checkpoints are written into it during training.
os.makedirs(args.snapshot_path, exist_ok=True)

# Seed every RNG this notebook imports (Python, NumPy, torch), not only torch,
# so that any code path drawing from `random` or `np.random` is repeatable too.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)


<torch._C.Generator at 0x2886e2ff810>

In [4]:
#create HepaticDataset
from PIL import Image   

class HepaticDataset(Dataset):
    """Paired grayscale-image / segmentation-mask dataset read from two folders.

    Every file name found in ``image_dir`` is looked up under the same name in
    ``mask_dir``.

    NOTE(review): a class with the same name is defined again in the
    data-loading cell further down and replaces this one once that cell runs.

    Args:
        image_dir: Folder containing the input slices.
        mask_dir: Folder containing the masks (same file names as the images).
        transform: Optional callable applied to the *image only*. It must return
            a tensor whose last two dimensions are (H, W). When omitted, the
            image is converted to a float tensor of shape (1, H, W) in [0, 1].

    Each item is a dict with:
        ``'image'``: float tensor produced from the image.
        ``'label'``: int64 tensor of shape (H, W) holding the raw mask values.
    """

    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        # Sorted so that indices map to files deterministically.
        self.image_files = sorted(os.listdir(image_dir))

    def __len__(self):
        return len(self.image_files)

    @staticmethod
    def _mask_to_label(mask, size):
        """Turn a single-channel mask into an int64 (H, W) tensor of class ids.

        The mask must NOT go through the image transform: ``ToTensor`` rescales
        8-bit images to [0, 1], so a following ``.long()`` collapses small class
        ids (1, 2, ...) to 0, and the default ``Resize`` interpolation blends
        neighbouring ids. Here the raw values are kept and resizing is done by
        nearest-neighbour index lookup, which can only reproduce existing ids.

        Args:
            mask: PIL image or array-like with exactly two dimensions.
            size: Target (height, width).

        Raises:
            ValueError: If the mask is not two-dimensional.
        """
        label = torch.from_numpy(np.array(mask, dtype=np.int64))
        if label.ndim != 2:
            raise ValueError(
                f"expected a single-channel mask, got array of shape {tuple(label.shape)}"
            )
        out_h, out_w = int(size[0]), int(size[1])
        in_h, in_w = label.shape
        if (in_h, in_w) != (out_h, out_w):
            # Source row/column for every target row/column (nearest neighbour).
            rows = torch.div(torch.arange(out_h) * in_h, out_h, rounding_mode="floor")
            cols = torch.div(torch.arange(out_w) * in_w, out_w, rounding_mode="floor")
            label = label[rows][:, cols]
        return label

    def __getitem__(self, idx):
        file_name = self.image_files[idx]
        img_path = os.path.join(self.image_dir, file_name)
        mask_path = os.path.join(self.mask_dir, file_name)

        with Image.open(img_path) as img_file:
            image = img_file.convert("L")

        if self.transform is not None:
            image = self.transform(image)
        else:
            # No transform given: plain grayscale -> (1, H, W) float tensor in [0, 1].
            image = torch.from_numpy(np.array(image, dtype=np.float32) / 255.0).unsqueeze(0)

        # The mask follows the image's final spatial size.
        # NOTE(review): random geometric augmentations inside `transform` are not
        # mirrored onto the mask; mask pixel values are assumed to already be
        # class ids in [0, num_classes) -- confirm against the data.
        with Image.open(mask_path) as mask_file:
            label = self._mask_to_label(mask_file, image.shape[-2:])

        return {
            'image': image.float(),
            'label': label
        }


In [5]:
from sklearn.model_selection import train_test_split

In [6]:
#define transforms and loader
# Square resize to args.img_size followed by conversion to a tensor.
train_transform = transforms.Compose(
    [transforms.Resize((args.img_size,) * 2), transforms.ToTensor()]
)

# Collect the slice file names in a deterministic order, then hold out 20% of
# them for validation (the split is seeded with args.seed).
image_paths = os.listdir(os.path.join(args.root_path, '2D_Sliced_Images'))
image_paths.sort()
train_files, val_files = train_test_split(
    image_paths,
    random_state=args.seed,
    test_size=0.2,
)

# Update HepaticDataset to take file_list
class HepaticDataset(Dataset):
    """Paired grayscale-image / segmentation-mask dataset over an explicit file list.

    Args:
        image_dir: Folder containing the input slices.
        mask_dir: Folder containing the masks (same file names as the images).
        file_list: File names (relative to both folders) that make up this split.
        transform: Optional callable applied to the *image only*. It must return
            a tensor whose last two dimensions are (H, W). When omitted, the
            image is converted to a float tensor of shape (1, H, W) in [0, 1].

    Each item is a dict with:
        ``'image'``: float tensor produced from the image.
        ``'label'``: int64 tensor of shape (H, W) holding the raw mask values.
    """

    def __init__(self, image_dir, mask_dir, file_list, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.image_files = file_list
        self.transform = transform

    def __len__(self):
        return len(self.image_files)

    @staticmethod
    def _mask_to_label(mask, size):
        """Turn a single-channel mask into an int64 (H, W) tensor of class ids.

        The mask must NOT go through the image transform: ``ToTensor`` rescales
        8-bit images to [0, 1], so a following ``.long()`` collapses small class
        ids (1, 2, ...) to 0, and the default ``Resize`` interpolation blends
        neighbouring ids. Here the raw values are kept and resizing is done by
        nearest-neighbour index lookup, which can only reproduce existing ids.

        Args:
            mask: PIL image or array-like with exactly two dimensions.
            size: Target (height, width).

        Raises:
            ValueError: If the mask is not two-dimensional.
        """
        label = torch.from_numpy(np.array(mask, dtype=np.int64))
        if label.ndim != 2:
            raise ValueError(
                f"expected a single-channel mask, got array of shape {tuple(label.shape)}"
            )
        out_h, out_w = int(size[0]), int(size[1])
        in_h, in_w = label.shape
        if (in_h, in_w) != (out_h, out_w):
            # Source row/column for every target row/column (nearest neighbour).
            rows = torch.div(torch.arange(out_h) * in_h, out_h, rounding_mode="floor")
            cols = torch.div(torch.arange(out_w) * in_w, out_w, rounding_mode="floor")
            label = label[rows][:, cols]
        return label

    def __getitem__(self, idx):
        file_name = self.image_files[idx]
        img_path = os.path.join(self.image_dir, file_name)
        mask_path = os.path.join(self.mask_dir, file_name)

        with Image.open(img_path) as img_file:
            image = img_file.convert("L")

        if self.transform is not None:
            image = self.transform(image)
        else:
            # No transform given: plain grayscale -> (1, H, W) float tensor in [0, 1].
            image = torch.from_numpy(np.array(image, dtype=np.float32) / 255.0).unsqueeze(0)

        # The mask follows the image's final spatial size.
        # NOTE(review): random geometric augmentations inside `transform` are not
        # mirrored onto the mask; mask pixel values are assumed to already be
        # class ids in [0, num_classes) -- confirm against the data.
        with Image.open(mask_path) as mask_file:
            label = self._mask_to_label(mask_file, image.shape[-2:])

        return {'image': image.float(), 'label': label}

# `train_transform` is already defined at the top of this cell; it is reused
# here for both splits (it contains no augmentation, only resize + tensor).

# Define datasets and loaders
images_dir = os.path.join(args.root_path, '2D_Sliced_Images')
masks_dir = os.path.join(args.root_path, '2D_Sliced_Masks')

train_dataset = HepaticDataset(images_dir, masks_dir, train_files, transform=train_transform)
val_dataset = HepaticDataset(images_dir, masks_dir, val_files, transform=train_transform)

# On Windows, DataLoader workers are separate spawned processes that must
# re-import the Dataset class; a class defined inside a notebook typically
# cannot be found there, so load in the main process on that platform.
# Elsewhere honour the configured worker count instead of a hard-coded 4.
loader_workers = 0 if os.name == "nt" else args.num_workers
# Pinned host memory only helps (and is only available) when CUDA is present.
use_pinned_memory = torch.cuda.is_available()

train_loader = DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=loader_workers,
    pin_memory=use_pinned_memory,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=loader_workers,
    pin_memory=use_pinned_memory,
)


In [7]:
#define diceloss
class DiceLoss(nn.Module):
    """Multi-class soft Dice loss, averaged over classes.

    For each class the Dice coefficient is computed over the whole batch
    (dimensions 0, 2 and 3 are summed), and the loss is ``1 - mean(dice)``.

    Args:
        num_classes: Number of classes, i.e. the channel dimension of ``input``.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.num_classes = num_classes

    def forward(self, input, target, smooth=1e-5, softmax=True):
        """Compute the loss.

        Args:
            input: Tensor of shape (N, num_classes, H, W); logits when
                ``softmax`` is True, otherwise per-class probabilities.
            target: Integer tensor of shape (N, H, W) with class ids in
                ``[0, num_classes)``.
            smooth: Constant added to numerator and denominator so that a class
                absent from both prediction and target yields a Dice of 1.
            softmax: Apply a softmax over the class dimension first.

        Returns:
            Scalar tensor ``1 - mean(dice over classes)``.
        """
        if softmax:
            input = torch.softmax(input, dim=1)

        # Build the one-hot target on the target's own device. Indexing a CPU
        # identity matrix with CUDA labels (torch.eye(C)[target]) raises a
        # device-mismatch error, which is what the GPU training loop would hit.
        target_onehot = nn.functional.one_hot(target.long(), num_classes=self.num_classes)
        target_onehot = target_onehot.permute(0, 3, 1, 2).to(device=input.device, dtype=torch.float32)

        dims = (0, 2, 3)
        intersection = torch.sum(input * target_onehot, dims)
        cardinality = torch.sum(input + target_onehot, dims)

        dice = (2. * intersection + smooth) / (cardinality + smooth)
        return 1. - dice.mean()


In [8]:
# loads a YAML configuration file and converts it into a CfgNode
from yacs.config import CfgNode as CN
import yaml

# Load YAML into a dictionary. The file is decoded explicitly as UTF-8 so the
# result does not depend on the operating system's default locale encoding.
with open('configs/swin_tiny_patch4_window7_224_lite.yaml', 'r', encoding='utf-8') as f:
    yaml_cfg = yaml.safe_load(f)

# Fail here with a clear message if the file is empty or is not a mapping.
if not isinstance(yaml_cfg, dict):
    raise ValueError("configs/swin_tiny_patch4_window7_224_lite.yaml must contain a YAML mapping")

# Convert dictionary to CfgNode (nested access support)
config = CN(yaml_cfg)


In [9]:
# Report the installed PyTorch build and whether it can see a CUDA device.
for caption, value in (
    ("Torch version:", torch.__version__),
    ("CUDA available:", torch.cuda.is_available()),
    ("GPU:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "No GPU found"),
    ("PyTorch CUDA version:", torch.version.cuda),
):
    print(caption, value)



Torch version: 2.5.1+cu121
CUDA available: True
GPU: NVIDIA GeForce GTX 1650
PyTorch CUDA version: 12.1


In [10]:
#Load SwinUnet model
from yacs.config import CfgNode as CN
import torch
import os
from networks.vision_transformer import SwinUnet

# ---- Load YAML config ----
from yaml import safe_load

# Decode explicitly as UTF-8 so parsing does not depend on the OS locale encoding.
with open('configs/swin_tiny_patch4_window7_224_lite.yaml', 'r', encoding='utf-8') as f:
    yaml_cfg = safe_load(f)

# Fail here with a clear message if the file is empty or is not a mapping.
if not isinstance(yaml_cfg, dict):
    raise ValueError("configs/swin_tiny_patch4_window7_224_lite.yaml must contain a YAML mapping")

config = CN(yaml_cfg)

# ---- Set extra training args ----
class Args(Config):
    """Training arguments used from the model cell onwards.

    Every field except ``img_size`` previously repeated the value already set
    on ``Config``; inheriting keeps a single source of truth for those values
    (and additionally exposes ``Config.pretrained_ckpt``). Only the image size
    differs: it is taken from the loaded YAML ``config``.

    NOTE(review): the data loaders are built earlier from ``Config.img_size``;
    confirm that it equals ``config.DATA.IMG_SIZE``, otherwise the batches will
    not have the spatial size the model is constructed for.
    """

    img_size = config.DATA.IMG_SIZE
args = Args()

# Make sure the snapshot directory exists and reseed torch before the model is built.
os.makedirs(args.snapshot_path, exist_ok=True)
torch.manual_seed(args.seed)

# Build the Swin-Unet from the YAML config, then let it load the pretrained
# weights that the config refers to.
model = SwinUnet(config=config, img_size=args.img_size, num_classes=args.num_classes, zero_head=True)
model.load_from(config)

# Wrap in DataParallel only when more than one GPU was requested, then move
# the (possibly wrapped) model onto the GPU.
model = (torch.nn.DataParallel(model) if args.n_gpu > 1 else model).cuda()


SwinTransformerSys expand initial----depths:[2, 2, 2, 2];depths_decoder:[1, 2, 2, 2];drop_path_rate:0.2;num_classes:3


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


---final upsample expand_first---
pretrained_path:./pretrained_ckpt/swin_tiny_patch4_window7_224.pth


  pretrained_dict = torch.load(pretrained_path, map_location=device)


---start load pretrained modle of swin encoder---


In [11]:
# Loss terms: pixel-wise cross-entropy and the multi-class Dice loss.
# The training loop combines the two with fixed weights.
dice_loss = DiceLoss(num_classes=args.num_classes)
ce_loss = nn.CrossEntropyLoss()

# SGD with momentum and weight decay; the starting learning rate is
# args.base_lr and the training loop rewrites it after every step.
optimizer = optim.SGD(
    model.parameters(),
    lr=args.base_lr,
    momentum=0.9,
    weight_decay=1e-4,
)


In [None]:
import time
import numpy as np
import matplotlib.pyplot as plt

# --- Training Setup ---
# Planned number of optimizer steps; the learning-rate decay in the training
# loop is expressed as a fraction of this total.
max_iterations = len(train_loader) * args.max_epochs
iter_num = 0
best_loss = float("inf")

# Early stopping: stop after `patience` consecutive epochs without improvement.
patience, trigger_times = 10, 0

# Per-epoch histories (one entry appended per completed epoch).
epoch_losses, epoch_accuracies, epoch_dice_scores, epoch_ious = [], [], [], []
val_losses, val_accuracies, val_dice_scores, val_ious = [], [], [], []

def compute_iou(pred, target, num_classes):
    """Mean intersection-over-union across classes for one batch.

    Args:
        pred: Integer tensor of predicted class ids (any shape).
        target: Integer tensor of ground-truth class ids, same number of
            elements as ``pred``.
        num_classes: Classes ``0 .. num_classes - 1`` are evaluated.

    Returns:
        The mean IoU over the classes that occur in ``pred`` or ``target``.
        Classes absent from both are ignored; if no class occurs at all the
        result is NaN.
    """
    # reshape (not view) so non-contiguous inputs, e.g. transposed or sliced
    # tensors, are flattened instead of raising.
    pred = pred.reshape(-1)
    target = target.reshape(-1)

    ious = []
    for cls in range(num_classes):
        pred_inds = pred == cls
        target_inds = target == cls
        intersection = (pred_inds & target_inds).sum().item()
        union = pred_inds.sum().item() + target_inds.sum().item() - intersection
        # A class missing from both tensors has no defined IoU; mark it NaN so
        # nanmean skips it rather than counting it as 0.
        ious.append(intersection / union if union else float('nan'))
    return np.nanmean(ious)

for epoch in range(args.max_epochs):
    # Each epoch: one optimisation pass over train_loader, one evaluation pass
    # over val_loader, then checkpointing / early stopping on the validation loss.

    # --- Training Phase ---
    model.train()
    epoch_ce = 0
    epoch_dice = 0
    epoch_accuracy = 0
    epoch_iou = 0

    start_time = time.time()

    for batch in tqdm(train_loader, total=len(train_loader), desc=f"Epoch {epoch}"):
        images = batch['image'].cuda()
        labels = batch['label'].cuda()

        outputs = model(images)
        loss_ce = ce_loss(outputs, labels)
        loss_dice = dice_loss(outputs, labels, softmax=True)
        # Fixed weighting of the two loss terms.
        loss = 0.4 * loss_ce + 0.6 * loss_dice

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Polynomial decay of the learning rate: it shrinks from base_lr
        # towards 0 as iter_num approaches max_iterations.
        iter_num += 1
        lr_ = args.base_lr * (1.0 - iter_num / max_iterations) ** 0.9
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr_

        epoch_ce += loss_ce.item()
        epoch_dice += loss_dice.item()

        # Per-batch pixel accuracy and mean IoU from the arg-max prediction.
        preds = torch.argmax(outputs, dim=1)
        epoch_accuracy += (preds == labels).float().mean().item()
        epoch_iou += compute_iou(preds, labels, args.num_classes)

    # Averages for the epoch (means of per-batch values)
    epoch_ce /= len(train_loader)
    epoch_dice /= len(train_loader)
    epoch_accuracy /= len(train_loader)
    epoch_iou /= len(train_loader)
    total_loss = 0.4 * epoch_ce + 0.6 * epoch_dice

    epoch_losses.append(total_loss)
    epoch_accuracies.append(epoch_accuracy)
    epoch_dice_scores.append(1 - epoch_dice)  # Dice Score = 1 - Dice Loss
    epoch_ious.append(epoch_iou)

    end_time = time.time()
    print(f"Epoch {epoch} took {(end_time - start_time)/60:.2f} minutes")
    print(f" → Train Loss: {total_loss:.4f} | CE: {epoch_ce:.4f} | Dice Loss: {epoch_dice:.4f}")
    print(f" → Train Accuracy: {epoch_accuracy:.4f} | Train IoU: {epoch_iou:.4f}")

    # --- Validation Phase ---
    model.eval()
    val_loss = 0
    val_correct = 0
    val_iou = 0
    val_dice = 0

    with torch.no_grad():
        for batch in val_loader:
            images = batch['image'].cuda()
            labels = batch['label'].cuda()

            outputs = model(images)
            loss_ce_val = ce_loss(outputs, labels)
            loss_dice_val = dice_loss(outputs, labels)
            loss_val = 0.4 * loss_ce_val + 0.6 * loss_dice_val
            val_loss += loss_val.item()

            preds = torch.argmax(outputs, dim=1)
            val_correct += (preds == labels).float().mean().item()
            val_iou += compute_iou(preds, labels, args.num_classes)
            val_dice += 1 - loss_dice_val.item()

    val_loss /= len(val_loader)
    val_acc = val_correct / len(val_loader)
    val_iou /= len(val_loader)
    val_dice /= len(val_loader)

    val_losses.append(val_loss)
    val_accuracies.append(val_acc)
    val_ious.append(val_iou)
    val_dice_scores.append(val_dice)

    print(f" → Val Loss: {val_loss:.4f} | Val Accuracy: {val_acc:.4f} | Val IoU: {val_iou:.4f} | Val Dice: {val_dice:.4f}")

    # --- Checkpointing and Early Stopping Logic ---
    # 'last_model.pth' is refreshed every epoch so it really is the latest state.
    torch.save(model.state_dict(), os.path.join(args.snapshot_path, 'last_model.pth'))

    # Model selection and early stopping are driven by the validation loss;
    # the training loss keeps improving while a model overfits, so it cannot
    # identify the best checkpoint or signal when to stop.
    if val_loss < best_loss:
        torch.save(model.state_dict(), os.path.join(args.snapshot_path, 'best_model.pth'))
        best_loss = val_loss
        trigger_times = 0
    else:
        trigger_times += 1
        print(f"EarlyStopping counter: {trigger_times} out of {patience}")
        if trigger_times >= patience:
            print("Early stopping triggered. Stopping training.")
            break

# === Plotting Accuracy and Loss ===
# One x-position per completed epoch, numbered from 1.
epochs = range(1, len(epoch_losses) + 1)
fig, (ax_acc, ax_loss) = plt.subplots(1, 2, figsize=(14, 6))

# Left panel: accuracy curves; right panel: loss curves.
# Train uses 'o' markers and validation 'x' markers in both panels.
for ax, train_series, val_series, quantity in (
    (ax_acc, epoch_accuracies, val_accuracies, 'Accuracy'),
    (ax_loss, epoch_losses, val_losses, 'Loss'),
):
    ax.plot(epochs, train_series, label=f'Train {quantity}', marker='o')
    ax.plot(epochs, val_series, label=f'Val {quantity}', marker='x')
    ax.set_title(f'{quantity} per Epoch')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(quantity)
    ax.grid(True)
    ax.legend()

fig.tight_layout()
plt.show()
