# IMPORTS
- Note: do not run a model's training more than once without restarting the kernel, because memory runs out after the first run.

In [1]:
# ===============================
# 0️⃣ Imports & Setup
# ===============================
import h5py
import numpy as np
import torch
from typing import List, Optional
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import SegformerForSemanticSegmentation
import matplotlib.pyplot as plt
from tqdm import tqdm
import random
%load_ext autoreload
%autoreload 2
from dataset import SASDataset , Data_loader
from segformer_model import segformer_SAS_2channel
from UNet import UNet
from train import compute_metrics, train_model_segformer, train_model_UNet
# Seed every global RNG this notebook imports so that anything drawing from
# them (e.g. shuffling, weight initialisation) is repeatable after a kernel
# restart. Code that builds its own generator is unaffected.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [2]:
# Print the GPU status table (driver / CUDA versions, memory in use, utilisation).
!nvidia-smi

Wed Nov 19 08:46:42 2025       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.161.08             Driver Version: 535.161.08   CUDA Version: 12.4     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A100-SXM4-40GB          On  | 00000000:07:00.0 Off |                    0 |
| N/A   43C    P0             262W / 400W |   1100MiB / 40960MiB |    100%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
                                                                    

# LOADING DATASET (BATCH SIZE = 3 WORKS WITHOUT THROWING AN OUT-OF-MEMORY ERROR)

In [2]:
# Build the train / validation / test loaders in a single call.
# NOTE(review): no batch size is passed here, so it is presumably fixed inside
# `Data_loader` — confirm in dataset.py.
train_loader, val_loader, test_loader = Data_loader()

2-channel images: (129, 2, 1001, 1001)
Masks: (129, 1001, 1001)
Train: 96 | Val: 24 | Test: 9


# SEGFORMER (B5 USED) TRAINING

In [None]:
# Build the model via `segformer_SAS_2channel` with 9 classes and move it to DEVICE.
model = segformer_SAS_2channel(num_class=9)
model.to(DEVICE)

# Count parameter elements: everything first, then the subset that receives
# gradients; the frozen count is the difference between the two.
total_params = sum(weight.numel() for weight in model.parameters())
trainable_params = sum(
    weight.numel() for weight in model.parameters() if weight.requires_grad
)
non_trainable_params = total_params - trainable_params

print(f"Trainable parameters: {trainable_params:,}")
print(f"Non-trainable parameters: {non_trainable_params:,}")
print(f"Total parameters: {trainable_params + non_trainable_params:,}")

In [13]:
# Launch SegFormer training with the train / validation loaders and the
# hyperparameters listed below.
# NOTE: `model` must still be the SegFormer instance here; the UNet section
# further down rebinds the same name.
train_model_segformer(
    model,
    train_loader,
    val_loader,
    NUM_CLASSES=9,
    NUM_EPOCHS=500,
    LR=1e-5,
    WEIGHT_DECAY=0.01,
)

# UNET TRAINING 

In [8]:
# Build the UNet (n_channels=2, n_classes=9, n_out_channels=64) on DEVICE.
# This rebinds `model`, replacing whatever it referred to before.
model = UNet(n_channels=2, n_classes=9, n_out_channels=64).to(DEVICE)

# ==== Parameter summary ====
# Count parameter elements: everything first, then the subset that receives
# gradients; the frozen count is the difference between the two.
total_params = sum(weight.numel() for weight in model.parameters())
trainable_params = sum(
    weight.numel() for weight in model.parameters() if weight.requires_grad
)
non_trainable_params = total_params - trainable_params

print(f"Trainable parameters: {trainable_params:,}")
print(f"Non-trainable parameters: {non_trainable_params:,}")
print(f"Total parameters: {trainable_params + non_trainable_params:,}")



In [9]:
# Launch UNet training with the train / validation loaders and the
# hyperparameters listed below.
# NOTE: `model` must be the UNet instance built in the preceding cell.
train_model_UNet(
    model,
    train_loader,
    val_loader,
    NUM_CLASSES=9,
    NUM_EPOCHS=500,
    LR=1e-5,
    WEIGHT_DECAY=0.01,
)

Trainable parameters: 31,043,465
Non-trainable parameters: 0
Total parameters: 31,043,465
