<a href="https://colab.research.google.com/github/Sunidhi-Gautam/FL_Implementation/blob/main/FL_KD.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
from google.colab import files
files.upload()

Saving kaggle.json to kaggle.json


{'kaggle.json': b'{"username":"sunidhigautam26","key":"df9c2dd0e0f5995bb0a2fd692c71ec16"}'}

In [3]:
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json

**Imports for the Dataset Download**

In [50]:
import os
import random
import shutil
from collections import Counter
import numpy as np
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, WeightedRandomSampler
from torchvision import datasets, transforms, models

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


In [51]:
RAW_DATASETS_DIR = "/content/raw_teacher_datasets"
os.makedirs(RAW_DATASETS_DIR, exist_ok=True)

print("Raw datasets will be stored in:", RAW_DATASETS_DIR)

**Dataset 1: Tomato Village**

In [6]:
print("Downloading Dataset 1: Tomato Village (GitHub)")

!git clone https://github.com/mamta-joshi-gehlot/Tomato-Village.git /content/tmp_tomato_village

shutil.move(
    "/content/tmp_tomato_village",
    os.path.join(RAW_DATASETS_DIR, "dataset_1_tomato_village")
)

print("Dataset 1 ready\n")

Downloading Dataset 1: Tomato Village (GitHub)
Cloning into '/content/tmp_tomato_village'...
remote: Enumerating objects: 45041, done.[K
remote: Counting objects: 100% (26934/26934), done.[K
remote: Compressing objects: 100% (15220/15220), done.[K
remote: Total 45041 (delta 11469), reused 26919 (delta 11462), pack-reused 18107 (from 1)[K
Receiving objects: 100% (45041/45041), 3.15 GiB | 16.29 MiB/s, done.
Resolving deltas: 100% (14334/14334), done.
Updating files: 100% (53289/53289), done.
Dataset 1 ready



**Dataset 2: Kaggle 1 (Tomato Leaf)**

In [7]:
print("Downloading Dataset 2: Tomato Leaf (Kaggle)")

!kaggle datasets download -d kaustubhb999/tomatoleaf -p /content

zip_path = "/content/tomatoleaf.zip"
extract_path = "/content/tmp_tomatoleaf"

with zipfile.ZipFile(zip_path, "r") as z:
    z.extractall(extract_path)

shutil.move(
    extract_path,
    os.path.join(RAW_DATASETS_DIR, "dataset_2_tomatoleaf")
)

print("Dataset 2 ready\n")

Downloading Dataset 2: Tomato Leaf (Kaggle)
Dataset URL: https://www.kaggle.com/datasets/kaustubhb999/tomatoleaf
License(s): CC0-1.0
Downloading tomatoleaf.zip to /content
 86% 153M/179M [00:00<00:00, 1.60GB/s]
100% 179M/179M [00:00<00:00, 1.25GB/s]
Dataset 2 ready



**Dataset 3: Kaggle 2 (Tomato – Ashish Motwani)**

In [8]:
print("Downloading Dataset 3: Tomato (Kaggle - Ashish Motwani)")

!kaggle datasets download -d ashishmotwani/tomato -p /content

zip_path = "/content/tomato.zip"
extract_path = "/content/tmp_tomato_3"

with zipfile.ZipFile(zip_path, "r") as z:
    z.extractall(extract_path)

shutil.move(
    extract_path,
    os.path.join(RAW_DATASETS_DIR, "dataset_3_tomato")
)

print("Dataset 3 ready\n")

Downloading Dataset 3: Tomato (Kaggle - Ashish Motwani)
Dataset URL: https://www.kaggle.com/datasets/ashishmotwani/tomato
License(s): copyright-authors
Downloading tomato.zip to /content
 99% 1.36G/1.37G [00:05<00:00, 256MB/s]
100% 1.37G/1.37G [00:05<00:00, 280MB/s]
Dataset 3 ready



**Dataset 4: Kaggle 3 (Tomato Disease)**

In [9]:
print("Downloading Dataset 4: Tomato Diseases (Kaggle)")

!kaggle datasets download -d luisolazo/tomato-diseases -p /content

zip_path = "/content/tomato-diseases.zip"
extract_path = "/content/tmp_tomato_diseases"

with zipfile.ZipFile(zip_path, "r") as z:
    z.extractall(extract_path)

shutil.move(
    extract_path,
    os.path.join(RAW_DATASETS_DIR, "dataset_4_tomato_diseases")
)

print("Dataset 4 ready\n")

Downloading Dataset 4: Tomato Diseases (Kaggle)
Dataset URL: https://www.kaggle.com/datasets/luisolazo/tomato-diseases
License(s): CC0-1.0
Downloading tomato-diseases.zip to /content
 94% 386M/411M [00:01<00:00, 301MB/s]
100% 411M/411M [00:01<00:00, 376MB/s]
Dataset 4 ready



**Verify Downloads**

In [10]:
print("\nFinal raw datasets available:\n")

for d in sorted(os.listdir(RAW_DATASETS_DIR)):
    print("•", d)


Final raw datasets available:

• dataset_1_tomato_village
• dataset_2_tomatoleaf
• dataset_3_tomato
• dataset_4_tomato_diseases


**Canonical Class Map**

In [16]:
CLASS_MAP = {
    # Early blight
    "early_blight": "early_blight",
    "Early_blight": "early_blight",
    "Early Blight": "early_blight",
    "tomato___early_blight": "early_blight",
    "tomato__early blight": "early_blight",
    "EARLY-BLIGHT": "early_blight",


    # Late blight
    "late_blight": "late_blight",
    "Late_blight": "late_blight",
    "Late Blight": "late_blight",
    "tomato___late_blight": "late_blight",
    "tomato__late blight": "late_blight",


    # Bacterial spot
    "bacterial_spot": "bacterial_spot",
    "Bacterial Spot": "bacterial_spot",
    "Tomato___Bacterial_spot": "bacterial_spot",
    "Tomato_bacterial_spot": "bacterial_spot",


    # Leaf mold
    "leaf_mold": "leaf_mold",
    "Leaf Mold": "leaf_mold",
    "Leaf_Mold": "leaf_mold",
    "tomato___Leaf_Mold": "leaf_mold",


    # Healthy
    "healthy": "healthy",
    "Healthy": "healthy",
    "Tomato___healthy": "healthy",


    # Target spot
    "target_spot": "target_spot",
    "Target Spot": "target_spot",
    "Target_Spot": "target_spot",
    "target_spot___": "target_spot",
    "tomato___Target_Spot": "target_spot",


    # Powdery Mildew
    "powdery_mildew": "powdery_mildew",
    "Powdery Mildew": "powdery_mildew",
    "Powdery_mildew": "powdery_mildew",
    "Powdery_Mildew": "powdery_mildew",


    # Septoria Leaf Spot
    "septoria_leaf_spot": "septoria_leaf_spot",
    "Septoria Leaf Spot": "septoria_leaf_spot",
    "Septorialeafspot": "septoria_leaf_spot",
    "tomato___Septoria_leaf_spot": "septoria_leaf_spot",


    # Mosaic virus
    "mosaic_virus": "mosaic_virus",
    "Tomato_mosaic_virus": "mosaic_virus",
    "tomato_mosaic_virus": "mosaic_virus",
    "Tomato mosaic virus": "mosaic_virus",
    "tomato___Tomato_mosaic_virus": "mosaic_virus",


    # Spider mites (Two-spotted)
    "spider_mites_two_spotted_spider_mite": "spider_mites",
    "Spider Mites Two-spotted spider_mite": "spider_mites",
    "Spider_mites": "spider_mites",
    "spider_mites": "spider_mites",
    "twospotted_spider_mite": "spider_mites",
    "Tomato___Spider_mites Two-spotted_spider_mite": "spider_mites",


    # Yellow Leaf Curl Virus
    "yellow_leaf_curl_virus": "yellow_leaf_curl_virus",
    "TomatoYellowLeafCurlVirus": "yellow_leaf_curl_virus",
    "Tomato_Yellow_Leaf_Curl_Virus": "yellow_leaf_curl_virus",
    "Yellow Leaf Curl Virus": "yellow_leaf_curl_virus",
    "tomato___Tomato_Yellow_Leaf_Curl_Virus": "yellow_leaf_curl_virus",


    # Leaf Miner
    "leaf_miner": "leaf_miner",
    "Leaf Miner": "leaf_miner",
    "leaf miner": "leaf_miner",


    # Nitrogen deficiency
    "nitrogen_deficiency": "nitrogen_deficiency",
    "Nitrogen Deficiency": "nitrogen_deficiency",


    # Potassium deficiency
    "potassium_deficiency": "potassium_deficiency",
    "Pottassium Deficiency": "potassium_deficiency",
    "Potassium Deficiency": "potassium_deficiency",


    # Magnesium deficiency
    "magnesium_deficiency": "magnesium_deficiency",
    "Magnesium Deficiency": "magnesium_deficiency",


    # Spotted Wilt Virus
    "spotted_wilt_virus": "spotted_wilt_virus",
    "Spotted Wilt Virus": "spotted_wilt_virus",
    "Spotted_Wilt_Virus": "spotted_wilt_virus",
    "Spotted wilt virus": "spotted_wilt_virus",
}

In [17]:
RAW_DATASETS_DIR = "/content/raw_teacher_datasets"
TEACHER_DATASET_DIR = "/content/teacher_dataset"

os.makedirs(TEACHER_DATASET_DIR, exist_ok=True)

In [18]:
IGNORE_FOLDERS = {
    "train", "val", "test",
    "images", "image", "imgs"
}

**Normalize Class Names**

In [23]:
def normalize_class_name(raw):
    raw = raw.strip().lower()

    # Ignore structural folders
    if raw in IGNORE_FOLDERS:
        return None

    # Remove PlantVillage prefix
    if raw.startswith("tomato___"):
        raw = raw.replace("tomato___", "")

    # Cleanup
    raw = raw.replace("-", "_")
    raw = raw.replace(" ", "_")
    raw = raw.replace("__", "_")

    # ---- Canonical merges ----

    if "spider" in raw or "mite" in raw:
        return "spider_mites"

    if "yellow" in raw and "curl" in raw:
        return "yellow_leaf_curl_virus"

    if "mosaic" in raw:
        return "mosaic_virus"

    if "septoria" in raw:
        return "septoria_leaf_spot"

    if "early" in raw and "blight" in raw:
        return "early_blight"

    if "late" in raw and "blight" in raw:
        return "late_blight"

    if "target" in raw:
        return "target_spot"

    if "leaf" in raw and "mold" in raw:
        return "leaf_mold"

    # Nutrient deficiencies
    if "nitrogen" in raw:
        return "nitrogen_deficiency"

    if "pottassium" in raw or "potassium" in raw:
        return "potassium_deficiency"

    if "magnesium" in raw:
        return "magnesium_deficiency"

    # Healthy
    if raw == "healthy":
        return "healthy"

    return raw

In [24]:
import shutil, os

TEACHER_DATASET_DIR = "/content/teacher_dataset"

if os.path.exists(TEACHER_DATASET_DIR):
    shutil.rmtree(TEACHER_DATASET_DIR)

os.makedirs(TEACHER_DATASET_DIR)

**Merge & Normalizing All Datasets**

In [25]:
IMG_EXTS = (".jpg", ".jpeg", ".png")

class_counter = {}
total_images = 0

for dataset in sorted(os.listdir(RAW_DATASETS_DIR)):
    dataset_path = os.path.join(RAW_DATASETS_DIR, dataset)
    print(f"\nProcessing {dataset}")

    for root, _, files in os.walk(dataset_path):
        imgs = [f for f in files if f.lower().endswith(IMG_EXTS)]
        if not imgs:
            continue

        raw_class = os.path.basename(root)
        norm_class = normalize_class_name(raw_class)

        if norm_class is None:
            continue

        dest_cls_dir = os.path.join(TEACHER_DATASET_DIR, norm_class)
        os.makedirs(dest_cls_dir, exist_ok=True)

        for img in imgs:
            src = os.path.join(root, img)
            dst = os.path.join(dest_cls_dir, f"{dataset}_{img}")
            shutil.copy(src, dst)

            class_counter[norm_class] = class_counter.get(norm_class, 0) + 1
            total_images += 1


Processing dataset_1_tomato_village

Processing dataset_2_tomatoleaf

Processing dataset_3_tomato

Processing dataset_4_tomato_diseases


**Teacher Dataset Summary**

In [26]:
print("\n===== CLEAN TEACHER DATASET SUMMARY =====")
print(f"Total images: {total_images}")
print(f"Total classes: {len(class_counter)}\n")

for cls, cnt in sorted(class_counter.items()):
    print(f"{cls:<30} {cnt}")


===== CLEAN TEACHER DATASET SUMMARY =====
Total images: 70252
Total classes: 16

bacterial_spot                 6781
early_blight                   7315
healthy                        7499
late_blight                    8446
leaf_miner                     1024
leaf_mold                      6497
magnesium_deficiency           936
mosaic_virus                   5478
nitrogen_deficiency            360
potassium_deficiency           72
powdery_mildew                 1256
septoria_leaf_spot             6499
spider_mites                   4958
spotted_wilt_virus             517
target_spot                    4938
yellow_leaf_curl_virus         7676


**Teacher Dataset Solitting**

In [52]:
RAW_TEACHER = "/content/teacher_dataset"
SPLIT_DIR = "/content/teacher_dataset_split"

TRAIN_DIR = os.path.join(SPLIT_DIR, "train")
VAL_DIR = os.path.join(SPLIT_DIR, "val")
TEST_DIR = os.path.join(SPLIT_DIR, "test")

for d in [TRAIN_DIR, VAL_DIR, TEST_DIR]:
    if os.path.exists(d):
        shutil.rmtree(d)
    os.makedirs(d)

train_ratio = 0.7
val_ratio   = 0.15
test_ratio  = 0.15

random.seed(42)
classes = sorted(os.listdir(RAW_TEACHER))

for cls in classes:
    cls_path = os.path.join(RAW_TEACHER, cls)
    if not os.path.isdir(cls_path):
        continue
    images = os.listdir(cls_path)
    random.shuffle(images)

    n = len(images)
    n_train = int(n * train_ratio)
    n_val   = int(n * val_ratio)
    n_test  = n - n_train - n_val

    train_imgs = images[:n_train]
    val_imgs   = images[n_train:n_train+n_val]
    test_imgs  = images[n_train+n_val:]

    for folder, imgs in zip([TRAIN_DIR, VAL_DIR, TEST_DIR], [train_imgs, val_imgs, test_imgs]):
        cls_folder = os.path.join(folder, cls)
        os.makedirs(cls_folder, exist_ok=True)
        for img in imgs:
            shutil.copy(os.path.join(cls_path, img), os.path.join(cls_folder, img))

**Data Transforms & DataLoaders**

In [56]:
BATCH_SIZE = 32

train_transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
])

val_transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
])

train_dataset = datasets.ImageFolder(TRAIN_DIR, transform=train_transform)
val_dataset   = datasets.ImageFolder(VAL_DIR, transform=val_transform)
test_dataset  = datasets.ImageFolder(TEST_DIR, transform=val_transform)

# WeightedRandomSampler for class imbalance
class_counts = [0]*len(train_dataset.classes)
for _, label in train_dataset.samples:
    class_counts[label] += 1

class_sample_weights = [1.0 / class_counts[label] for _, label in train_dataset.samples]

train_sampler = WeightedRandomSampler(
    weights=class_sample_weights,
    num_samples=len(class_sample_weights),
    replacement=True
)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=train_sampler, num_workers=2, pin_memory=True)
val_loader   = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)
test_loader  = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)

**Handling Class Imbalance**

**Class Frequencies**

In [60]:
train_dir = os.path.join(TEACHER_DATASET_SPLIT, "train")
class_counts = Counter()

for cls in os.listdir(train_dir):
    cls_path = os.path.join(train_dir, cls)
    class_counts[cls] = len(os.listdir(cls_path))

beta = 0.999
classes = sorted(class_counts.keys())
weights = []

for c in classes:
    n = class_counts[c]
    eff_n = 1.0 - np.power(beta, n)
    w = (1.0 - beta) / eff_n
    weights.append(w)

weights = np.array(weights)
weights = weights / weights.sum() * len(weights)
class_weights = torch.tensor(weights, dtype=torch.float).cuda()

NameError: name 'TEACHER_DATASET_SPLIT' is not defined

**Computing Class-Balanced Weights (CB-Loss)**

$$
w_c = \frac{1 - \beta}{1 - \beta^{n_c}}
$$

n_c = no. of samples in class c

β ϵ [0.9, 0.9999]

Larger β --> stronger balancing

In [58]:
import numpy as np
import torch

beta = 0.999

classes = sorted(class_counts.keys())
class_to_idx = {c: i for i, c in enumerate(classes)}

effective_num = {}
weights = []

for c in classes:
    n = class_counts[c]
    eff_n = 1.0 - np.power(beta, n)
    w = (1.0 - beta) / eff_n
    weights.append(w)
    effective_num[c] = eff_n

# Normalize weights
weights = np.array(weights)
weights = weights / weights.sum() * len(weights)

class_weights = torch.tensor(weights, dtype=torch.float)

print("\n===== CLASS WEIGHTS =====")
for c, w in zip(classes, class_weights):
    print(f"{c:<30} {w:.4f}")


===== CLASS WEIGHTS =====
bacterial_spot                 0.4604
early_blight                   0.4601
healthy                        0.4600
late_blight                    0.4598
leaf_miner                     0.7170
leaf_mold                      0.4607
magnesium_deficiency           0.7560
mosaic_virus                   0.4626
nitrogen_deficiency            1.5197
potassium_deficiency           6.6130
powdery_mildew                 0.6425
septoria_leaf_spot             0.4607
spider_mites                   0.4647
spotted_wilt_virus             1.1381
target_spot                    0.4648
yellow_leaf_curl_virus         0.4600


**Loading Teacher Dataset**

In [55]:
import timm

teacher_model = timm.create_model('efficientnet_b2', pretrained=True, num_classes=len(train_dataset.classes))
teacher_model = teacher_model.to(device)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


model.safetensors:   0%|          | 0.00/36.8M [00:00<?, ?B/s]

**Defining Teacher Model**

In [29]:
import os
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [30]:
TEACHER_DATASET = "/content/teacher_dataset"  # root folder
BATCH_SIZE = 32

In [31]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

In [32]:
teacher_dataset = datasets.ImageFolder(
    root=TEACHER_DATASET,
    transform=transform
)

print("Total samples:", len(teacher_dataset))
print("Classes:", teacher_dataset.classes)

Total samples: 65812
Classes: ['bacterial_spot', 'early_blight', 'healthy', 'late_blight', 'leaf_miner', 'leaf_mold', 'magnesium_deficiency', 'mosaic_virus', 'nitrogen_deficiency', 'potassium_deficiency', 'powdery_mildew', 'septoria_leaf_spot', 'spider_mites', 'spotted_wilt_virus', 'target_spot', 'yellow_leaf_curl_virus']


In [34]:
print(dict(enumerate(teacher_dataset.classes)))

{0: 'bacterial_spot', 1: 'early_blight', 2: 'healthy', 3: 'late_blight', 4: 'leaf_miner', 5: 'leaf_mold', 6: 'magnesium_deficiency', 7: 'mosaic_virus', 8: 'nitrogen_deficiency', 9: 'potassium_deficiency', 10: 'powdery_mildew', 11: 'septoria_leaf_spot', 12: 'spider_mites', 13: 'spotted_wilt_virus', 14: 'target_spot', 15: 'yellow_leaf_curl_virus'}


**Injecting Class Weights**

In [35]:
# Order MUST match teacher_dataset.classes
class_weights_dict = {
    "bacterial_spot": 0.4604,
    "early_blight": 0.4601,
    "healthy": 0.4600,
    "late_blight": 0.4598,
    "leaf_miner": 0.7170,
    "leaf_mold": 0.4607,
    "magnesium_deficiency": 0.7560,
    "mosaic_virus": 0.4626,
    "nitrogen_deficiency": 1.5197,
    "potassium_deficiency": 6.6130,
    "powdery_mildew": 0.6425,
    "septoria_leaf_spot": 0.4607,
    "spider_mites": 0.4647,
    "spotted_wilt_virus": 1.1381,
    "target_spot": 0.4648,
    "yellow_leaf_curl_virus": 0.4600
}

In [36]:
class_weights = torch.tensor(
    [class_weights_dict[c] for c in teacher_dataset.classes],
    dtype=torch.float
).cuda()   # move to GPU

In [37]:
criterion = torch.nn.CrossEntropyLoss(weight=class_weights)

**WeightedRandomSampler (Data-Level Fix)**

In [38]:
from torch.utils.data import WeightedRandomSampler

In [39]:
# Count samples per class
class_counts = [0] * len(teacher_dataset.classes)

for _, label in teacher_dataset.samples:
    class_counts[label] += 1

In [40]:
# Inverse frequency per class
class_sample_weights = [
    1.0 / class_counts[label]
    for _, label in teacher_dataset.samples
]

In [41]:
sampler = WeightedRandomSampler(
    weights=class_sample_weights,
    num_samples=len(class_sample_weights),
    replacement=True
)

In [42]:
teacher_loader = DataLoader(
    teacher_dataset,
    batch_size=BATCH_SIZE,
    sampler=sampler,
    num_workers=2,
    pin_memory=True
)

**Cap Extreme Oversampling**

In [44]:
MAX_CLASS_WEIGHT = 5.0

class_sample_weights = [
    min(1.0 / class_counts[label], MAX_CLASS_WEIGHT)
    for _, label in teacher_dataset.samples
]

**Teacher Model Training**

In [48]:
import torch
import torch.nn as nn
from torchvision import models

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load pretrained EfficientNet-B2
teacher_model = models.efficientnet_b2(weights=models.EfficientNet_B2_Weights.IMAGENET1K_V1)

# Replace the classifier for 16 classes (matches your teacher dataset)
num_features = teacher_model.classifier[1].in_features
teacher_model.classifier[1] = nn.Linear(num_features, 16)

teacher_model = teacher_model.to(device)

# Optimizer
import torch.optim as optim
optimizer = optim.Adam(teacher_model.parameters(), lr=1e-4)

# Criterion already defined using class_weights
# Example:
# criterion = torch.nn.CrossEntropyLoss(weight=class_weights)

In [49]:
NUM_EPOCHS = 10  # choose what works for you

for epoch in range(NUM_EPOCHS):
    teacher_model.train()
    running_loss = 0
    correct = 0
    total = 0

    for images, labels in teacher_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = teacher_model(images)    # <- use teacher_model here
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * images.size(0)
        _, preds = torch.max(outputs, 1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    epoch_loss = running_loss / total
    epoch_acc = correct / total
    print(f"Epoch {epoch+1}/{NUM_EPOCHS} - Loss: {epoch_loss:.4f}, Acc: {epoch_acc:.4f}")

KeyboardInterrupt: 