In [None]:
# Mount Google Drive at /content/drive; later cells read/write under /content/drive/MyDrive.
from google.colab import drive
drive.mount(mountpoint='/content/drive')

In [None]:
from google.colab import files

# Upload kaggle.json (the Kaggle API token) into the working directory; the next cell
# copies it into ~/.kaggle.
# files.upload() returns {filename: bytes}. Keep it in a variable instead of leaving the
# call as the cell's last expression, which would render the token contents in the output.
uploaded_files = files.upload()
print("Uploaded:", list(uploaded_files))

In [None]:
# Put the uploaded kaggle.json into ~/.kaggle for the `kaggle` CLI used by the download cells.
# chmod 600 makes the token readable/writable by the current user only.
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json

**Imports & Device Setup**

In [None]:
import os
import random
import shutil
import numpy as np
import torch
import timm
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import json

from sklearn.metrics import classification_report
from tqdm import tqdm
from collections import Counter
from torch.utils.data import DataLoader, WeightedRandomSampler
from torchvision import datasets, transforms

# Prefer the GPU when Colab provides one; everything below also runs (slowly) on CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

In [None]:
# Each download cell below moves its dataset into a sub-folder of this directory.
RAW_DATASETS_DIR = os.path.join("/content", "raw_teacher_datasets")
os.makedirs(RAW_DATASETS_DIR, exist_ok=True)

print("Raw datasets will be stored in:", RAW_DATASETS_DIR)

**Dataset 1: Tomato Village**

In [None]:
print("Downloading Dataset 1: Tomato Village (GitHub)")

!git clone https://github.com/mamta-joshi-gehlot/Tomato-Village.git /content/tmp_tomato_village

shutil.move(
    "/content/tmp_tomato_village",
    os.path.join(RAW_DATASETS_DIR, "dataset_1_tomato_village")
)

print("Dataset 1 ready\n")

**Dataset 2: Kaggle 1 (Tomato Leaf)**

In [None]:
import zipfile

print("Downloading Dataset 2: Tomato Leaf (Kaggle)")

!kaggle datasets download -d kaustubhb999/tomatoleaf -p /content

zip_path = "/content/tomatoleaf.zip"
extract_path = "/content/tmp_tomatoleaf"

with zipfile.ZipFile(zip_path, "r") as z:
    z.extractall(extract_path)

shutil.move(
    extract_path,
    os.path.join(RAW_DATASETS_DIR, "dataset_2_tomatoleaf")
)

print("Dataset 2 ready\n")

**Dataset 3: Kaggle 2 (Tomato – Ashish Motwani)**

In [None]:
print("Downloading Dataset 3: Tomato (Kaggle - Ashish Motwani)")

!kaggle datasets download -d ashishmotwani/tomato -p /content

zip_path = "/content/tomato.zip"
extract_path = "/content/tmp_tomato_3"

with zipfile.ZipFile(zip_path, "r") as z:
    z.extractall(extract_path)

shutil.move(
    extract_path,
    os.path.join(RAW_DATASETS_DIR, "dataset_3_tomato")
)

print("Dataset 3 ready\n")

**Dataset 4: Kaggle 3 (Tomato Diseases)**

In [None]:
print("Downloading Dataset 4: Tomato Diseases (Kaggle)")

!kaggle datasets download -d luisolazo/tomato-diseases -p /content

zip_path = "/content/tomato-diseases.zip"
extract_path = "/content/tmp_tomato_diseases"

with zipfile.ZipFile(zip_path, "r") as z:
    z.extractall(extract_path)

shutil.move(
    extract_path,
    os.path.join(RAW_DATASETS_DIR, "dataset_4_tomato_diseases")
)

print("Dataset 4 ready\n")

**Verify Downloads**

In [None]:
print("\nFinal raw datasets available:\n")

# One line per dataset folder (the bullet was previously stored as mis-decoded UTF-8 "â€¢").
for d in sorted(os.listdir(RAW_DATASETS_DIR)):
    print("•", d)

**Canonical Class Map**

In [None]:
# Known folder-name spellings across the source datasets, grouped by canonical label.
# NOTE(review): nothing in the visible notebook reads CLASS_MAP; normalize_class_name()
# applies its own keyword rules, so the two must be kept in sync by hand.
_CLASS_ALIASES = {
    "early_blight": [
        "early_blight", "Early_blight", "Early Blight",
        "tomato___early_blight", "tomato__early blight", "EARLY-BLIGHT",
    ],
    "late_blight": [
        "late_blight", "Late_blight", "Late Blight",
        "tomato___late_blight", "tomato__late blight",
    ],
    "bacterial_spot": [
        "bacterial_spot", "Bacterial Spot",
        "Tomato___Bacterial_spot", "Tomato_bacterial_spot",
    ],
    "leaf_mold": [
        "leaf_mold", "Leaf Mold", "Leaf_Mold", "tomato___Leaf_Mold",
    ],
    "healthy": [
        "healthy", "Healthy", "Tomato___healthy",
    ],
    "target_spot": [
        "target_spot", "Target Spot", "Target_Spot",
        "target_spot___", "tomato___Target_Spot",
    ],
    "powdery_mildew": [
        "powdery_mildew", "Powdery Mildew", "Powdery_mildew", "Powdery_Mildew",
    ],
    "septoria_leaf_spot": [
        "septoria_leaf_spot", "Septoria Leaf Spot",
        "Septorialeafspot", "tomato___Septoria_leaf_spot",
    ],
    "mosaic_virus": [
        "mosaic_virus", "Tomato_mosaic_virus", "tomato_mosaic_virus",
        "Tomato mosaic virus", "tomato___Tomato_mosaic_virus",
    ],
    # Spider mites (two-spotted)
    "spider_mites": [
        "spider_mites_two_spotted_spider_mite",
        "Spider Mites Two-spotted spider_mite",
        "Spider_mites", "spider_mites", "twospotted_spider_mite",
        "Tomato___Spider_mites Two-spotted_spider_mite",
    ],
    "yellow_leaf_curl_virus": [
        "yellow_leaf_curl_virus", "TomatoYellowLeafCurlVirus",
        "Tomato_Yellow_Leaf_Curl_Virus", "Yellow Leaf Curl Virus",
        "tomato___Tomato_Yellow_Leaf_Curl_Virus",
    ],
    "leaf_miner": [
        "leaf_miner", "Leaf Miner", "leaf miner",
    ],
    "nitrogen_deficiency": [
        "nitrogen_deficiency", "Nitrogen Deficiency",
    ],
    # "Pottassium" is a misspelling that appears in a source dataset
    "potassium_deficiency": [
        "potassium_deficiency", "Pottassium Deficiency", "Potassium Deficiency",
    ],
    "magnesium_deficiency": [
        "magnesium_deficiency", "Magnesium Deficiency",
    ],
    "spotted_wilt_virus": [
        "spotted_wilt_virus", "Spotted Wilt Virus",
        "Spotted_Wilt_Virus", "Spotted wilt virus",
    ],
}

# Flatten to {alias: canonical_label}, preserving the alias order above.
CLASS_MAP = {
    alias: canonical
    for canonical, aliases in _CLASS_ALIASES.items()
    for alias in aliases
}

In [None]:
# Input (raw downloads) and output (merged, class-normalized images) of the merge step.
RAW_DATASETS_DIR, TEACHER_DATASET_DIR = (
    "/content/raw_teacher_datasets",
    "/content/teacher_dataset",
)

os.makedirs(TEACHER_DATASET_DIR, exist_ok=True)

In [None]:
# Lower-cased folder names that describe dataset layout (splits, image dirs), not a class.
IGNORE_FOLDERS = set("train val test images image imgs".split())

**Normalize Class Names**

In [None]:
def normalize_class_name(raw, ignore=None):
    """Map a raw dataset folder name to a canonical lower_snake_case class label.

    The source datasets spell the same class differently (``Tomato___Early_blight``,
    ``Early Blight``, ``EARLY-BLIGHT``, ``Tomato_bacterial_spot`` ...). The name is
    lower-cased, separators are unified to single underscores, any ``tomato_`` prefix
    is dropped, and the result is matched against keyword rules (first match wins).

    Args:
        raw: Folder name, typically the basename of a directory that holds images.
        ignore: Names (compared after strip + lower-casing) of structural folders that
            are not classes. Defaults to the module-level ``IGNORE_FOLDERS``.

    Returns:
        The canonical label. Returns ``None`` if the folder is in ``ignore`` (the caller
        skips it). If no rule matches, returns the cleaned name itself, so unknown
        classes are kept rather than dropped.
    """
    if ignore is None:
        ignore = IGNORE_FOLDERS

    name = raw.strip().lower()

    # Structural folders (train/val/test, images, ...) are not classes
    if name in ignore:
        return None

    # Unify separators, then collapse *runs* of "_" and trim leading/trailing ones.
    # (str.replace("__", "_") is single-pass: "___" would only become "__".)
    name = name.replace("-", "_").replace(" ", "_")
    name = "_".join(part for part in name.split("_") if part)

    # Drop the crop prefix in any spelling ("Tomato___X" PlantVillage-style folders,
    # "Tomato_X" elsewhere), so "Tomato_bacterial_spot" merges with "bacterial_spot".
    if name.startswith("tomato_"):
        name = name[len("tomato_"):]

    # Keyword rules, checked in order; a rule matches when every keyword is a substring.
    # OR-conditions ("spider" / "mite", potassium + its misspelling) are separate entries.
    rules = (
        (("spider",), "spider_mites"),
        (("mite",), "spider_mites"),
        (("yellow", "curl"), "yellow_leaf_curl_virus"),
        (("mosaic",), "mosaic_virus"),
        (("septoria",), "septoria_leaf_spot"),
        (("early", "blight"), "early_blight"),
        (("late", "blight"), "late_blight"),
        (("target",), "target_spot"),
        (("leaf", "mold"), "leaf_mold"),
        (("leaf", "miner"), "leaf_miner"),
        (("bacterial",), "bacterial_spot"),
        (("powdery",), "powdery_mildew"),
        (("spotted", "wilt"), "spotted_wilt_virus"),
        # Nutrient deficiencies
        (("nitrogen",), "nitrogen_deficiency"),
        (("pottassium",), "potassium_deficiency"),
        (("potassium",), "potassium_deficiency"),
        (("magnesium",), "magnesium_deficiency"),
    )
    for keywords, label in rules:
        if all(keyword in name for keyword in keywords):
            return label

    # Includes "healthy" (after prefix removal "Tomato___healthy" -> "healthy")
    return name

In [None]:
TEACHER_DATASET_DIR = "/content/teacher_dataset"

# Rebuild from scratch: remove any previous merge output, then recreate an empty folder.
if os.path.exists(TEACHER_DATASET_DIR):
    shutil.rmtree(TEACHER_DATASET_DIR)
os.makedirs(TEACHER_DATASET_DIR, exist_ok=False)

**Merging & Normalizing All Datasets**

In [None]:
import hashlib

IMG_EXTS = (".jpg", ".jpeg", ".png")

class_counter = {}
total_images = 0
skipped_duplicates = 0
seen_hashes = set()

for dataset in sorted(os.listdir(RAW_DATASETS_DIR)):
    dataset_path = os.path.join(RAW_DATASETS_DIR, dataset)
    print(f"\nProcessing {dataset}")

    # `filenames`, not `files`: the latter would shadow `google.colab.files`.
    for root, dirnames, filenames in os.walk(dataset_path):
        # os.walk order is filesystem-dependent; sort so the kept copy of a duplicate
        # image (and therefore the merged dataset) is the same on every run.
        dirnames.sort()
        imgs = sorted(f for f in filenames if f.lower().endswith(IMG_EXTS))
        if not imgs:
            continue

        norm_class = normalize_class_name(os.path.basename(root))
        if norm_class is None:
            continue

        dest_cls_dir = os.path.join(TEACHER_DATASET_DIR, norm_class)
        os.makedirs(dest_cls_dir, exist_ok=True)

        # Encode the sub-folder path in the file name. Using only f"{dataset}_{img}" let
        # same-named files from different folders of one dataset (e.g. train/ and val/)
        # overwrite each other while still being counted.
        rel_dir = os.path.relpath(root, dataset_path).replace(os.sep, "__")

        for img in imgs:
            src = os.path.join(root, img)

            # Skip byte-identical copies (datasets can ship the same image in several
            # folders); otherwise one image could land in both train and test after splitting.
            with open(src, "rb") as fh:
                digest = hashlib.md5(fh.read()).hexdigest()
            if digest in seen_hashes:
                skipped_duplicates += 1
                continue
            seen_hashes.add(digest)

            dst = os.path.join(dest_cls_dir, f"{dataset}__{rel_dir}__{img}")
            shutil.copy(src, dst)

            class_counter[norm_class] = class_counter.get(norm_class, 0) + 1
            total_images += 1

print(f"\nSkipped {skipped_duplicates} byte-identical duplicate images")

**Teacher Dataset Summary**

In [None]:
print("\n===== CLEAN TEACHER DATASET SUMMARY =====")
print(f"Total images: {total_images}")
print(f"Total classes: {len(class_counter)}\n")

# One row per class, alphabetically, with its merged image count
for cls in sorted(class_counter):
    print(f"{cls:<30} {class_counter[cls]}")

**Teacher Dataset Splitting**

In [None]:
RAW_TEACHER = "/content/teacher_dataset"
SPLIT_DIR = "/content/teacher_dataset_split"

TRAIN_DIR = os.path.join(SPLIT_DIR, "train")
VAL_DIR = os.path.join(SPLIT_DIR, "val")
TEST_DIR = os.path.join(SPLIT_DIR, "test")

# Rebuild the split from scratch on every run
for d in [TRAIN_DIR, VAL_DIR, TEST_DIR]:
    if os.path.exists(d):
        shutil.rmtree(d)
    os.makedirs(d)

# Per-class (stratified) split. Test receives whatever train and val leave over, so
# test_ratio is informational only.
train_ratio = 0.7
val_ratio   = 0.15
test_ratio  = 0.15

random.seed(42)
classes = sorted(os.listdir(RAW_TEACHER))

for cls in classes:
    cls_path = os.path.join(RAW_TEACHER, cls)
    if not os.path.isdir(cls_path):
        continue
    # os.listdir order is filesystem-dependent: sort first, otherwise the seeded
    # shuffle does not reproduce the same split across runs/machines.
    images = sorted(os.listdir(cls_path))
    random.shuffle(images)

    n = len(images)
    n_train = int(n * train_ratio)
    n_val   = int(n * val_ratio)

    train_imgs = images[:n_train]
    val_imgs   = images[n_train:n_train+n_val]
    test_imgs  = images[n_train+n_val:]

    for folder, imgs in zip([TRAIN_DIR, VAL_DIR, TEST_DIR], [train_imgs, val_imgs, test_imgs]):
        cls_folder = os.path.join(folder, cls)
        os.makedirs(cls_folder, exist_ok=True)
        for img in imgs:
            shutil.copy(os.path.join(cls_path, img), os.path.join(cls_folder, img))

**Data Transforms & DataLoaders**

In [None]:
BATCH_SIZE = 32

# ImageNet channel statistics, shared by every split's preprocessing.
_imagenet_norm = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

# Training: resize plus light augmentation (horizontal flip, small rotation).
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    _imagenet_norm,
])

# Validation / test: deterministic resize only.
val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    _imagenet_norm,
])

train_dataset = datasets.ImageFolder(root=TRAIN_DIR, transform=train_transform)
val_dataset = datasets.ImageFolder(root=VAL_DIR, transform=val_transform)
test_dataset = datasets.ImageFolder(root=TEST_DIR, transform=val_transform)

**Handling Class Imbalance**

**Class Frequencies**

In [None]:
# Count entries per class folder of the training split, read directly from disk.
# (Counts every directory entry, not only image files.)
train_dir = os.path.join(SPLIT_DIR, "train")

class_counts = Counter({
    entry: len(os.listdir(os.path.join(train_dir, entry)))
    for entry in os.listdir(train_dir)
    if os.path.isdir(os.path.join(train_dir, entry))
})

print("===== CLASS FREQUENCIES (Train Set) =====")
for cls_name, count in sorted(class_counts.items()):
    print(f"{cls_name:<30} {count}")

**Computing Class-Balanced Weights (CB-Loss)**

$$
w_c = \frac{1 - \beta}{1 - \beta^{n_c}}
$$

where $n_c$ is the number of training samples in class $c$ and $\beta \in [0.9, 0.9999]$.

A larger $\beta$ gives stronger re-balancing toward rare classes.

In [None]:
# Per-class training sample counts, indexed by ImageFolder label id
class_counts = [0]*len(train_dataset.classes)
for _, label in train_dataset.samples:
    class_counts[label] += 1

# Class-balanced weights (Cui et al., CVPR 2019): w_c = (1 - beta) / (1 - beta**n_c),
# then rescaled so the weights sum to the number of classes.
beta = 0.999
effective_num = [1 - beta**n for n in class_counts]
weights = np.array([(1 - beta) / e for e in effective_num])
weights = weights / weights.sum() * len(weights)
# .to(device) instead of .cuda(): `device` falls back to CPU, where .cuda() would raise.
class_weights = torch.tensor(weights, dtype=torch.float).to(device)

criterion = torch.nn.CrossEntropyLoss(weight=class_weights)

**WeightedRandomSampler**

In [None]:
MAX_CLASS_WEIGHT = 5.0

# Per-class sampling weight relative to the largest class: a sample of class c is drawn
# (largest / n_c) times as often as one from the largest class, capped at MAX_CLASS_WEIGHT.
# The previous form min(1.0 / n_c, MAX_CLASS_WEIGHT) is always <= 1, so the cap never
# applied and sampling was fully inverse-frequency.
# NOTE(review): `criterion` already uses class-balanced weights. Oversampling here as well
# compensates for imbalance twice; consider keeping only one of the two mechanisms.
largest_class = max(class_counts)
per_class_weight = [
    min(largest_class / n, MAX_CLASS_WEIGHT) if n else 0.0 for n in class_counts
]
sample_weights = [per_class_weight[label] for _, label in train_dataset.samples]

train_sampler = WeightedRandomSampler(weights=sample_weights,
                                      num_samples=len(sample_weights),
                                      replacement=True)

**DataLoaders**

In [None]:
_loader_opts = dict(batch_size=BATCH_SIZE, num_workers=2, pin_memory=True)

# Training order comes from the weighted sampler (DataLoader forbids shuffle=True with a sampler).
train_loader = DataLoader(train_dataset, sampler=train_sampler, **_loader_opts)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_opts)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_opts)

**Loading Teacher Model**

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pretrained EfficientNet-B2 backbone with a classifier head sized to our class list.
teacher_model = timm.create_model(
    "efficientnet_b2",
    pretrained=True,
    num_classes=len(train_dataset.classes),
).to(device)

optimizer = optim.Adam(params=teacher_model.parameters(), lr=1e-4)

**Training Loop**

In [None]:
NUM_EPOCHS = 10
history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}

# Track the best validation epoch and write its weights separately. The in-memory model
# (saved by the following cells) still holds the *last* epoch's weights.
BEST_CKPT_PATH = "teacher_efficientnet_b2_best.pth"
best_val_acc, best_epoch = float("-inf"), 0

for epoch in range(NUM_EPOCHS):
    print(f"\nEpoch {epoch+1}/{NUM_EPOCHS}")

    # --- Training ---
    teacher_model.train()
    running_loss, correct, total = 0, 0, 0

    train_loader_tqdm = tqdm(train_loader, desc="Training", leave=False)
    for images, labels in train_loader_tqdm:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = teacher_model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * images.size(0)
        _, preds = torch.max(outputs, 1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

        # Progress bar shows the running (epoch-so-far) loss and accuracy
        train_loader_tqdm.set_postfix({
            "Loss": f"{running_loss/total:.4f}",
            "Acc": f"{correct/total:.4f}"
        })

    train_loss = running_loss / total
    train_acc = correct / total

    # --- Validation ---
    teacher_model.eval()
    val_running_loss, val_correct, val_total = 0, 0, 0

    val_loader_tqdm = tqdm(val_loader, desc="Validation", leave=False)
    with torch.no_grad():
        for images, labels in val_loader_tqdm:
            images, labels = images.to(device), labels.to(device)
            outputs = teacher_model(images)
            loss = criterion(outputs, labels)

            val_running_loss += loss.item() * images.size(0)
            _, preds = torch.max(outputs, 1)
            val_correct += (preds == labels).sum().item()
            val_total += labels.size(0)

            val_loader_tqdm.set_postfix({
                "Loss": f"{val_running_loss/val_total:.4f}",
                "Acc": f"{val_correct/val_total:.4f}"
            })

    val_loss = val_running_loss / val_total
    val_acc = val_correct / val_total

    # --- Save history ---
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)

    # --- Best checkpoint ---
    if val_acc > best_val_acc:
        best_val_acc, best_epoch = val_acc, epoch + 1
        torch.save(teacher_model.state_dict(), BEST_CKPT_PATH)

    # --- Epoch summary ---
    print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | "
          f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")

print(f"\nBest Val Acc: {best_val_acc:.4f} at epoch {best_epoch} (weights saved to {BEST_CKPT_PATH})")

# Save history
torch.save(history, "teacher_history.pt")


In [None]:
# Persist the final-epoch weights (state_dict only) next to the notebook.
torch.save(obj=teacher_model.state_dict(), f="teacher_efficientnet_b2.pth")

In [None]:
# --- Plot training & validation curves ---
history = torch.load("teacher_history.pt")  # reload the history written by the training cell

epochs = range(1, len(history['train_loss']) + 1)

fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(14, 5))

# Same layout for both panels: (axis, history-key suffix, legend suffix, title, y-label)
panels = (
    (ax_loss, "loss", "Loss", "Training & Validation Loss", "Loss"),
    (ax_acc, "acc", "Acc", "Training & Validation Accuracy", "Accuracy"),
)
for ax, key, legend_suffix, title, ylabel in panels:
    ax.plot(epochs, history[f"train_{key}"], label=f"Train {legend_suffix}", marker='o')
    ax.plot(epochs, history[f"val_{key}"], label=f"Val {legend_suffix}", marker='o')
    ax.set_title(title)
    ax.set_xlabel("Epoch")
    ax.set_ylabel(ylabel)
    ax.grid(True)
    ax.legend()

plt.show()

In [None]:
# --- Evaluate on test set ---
teacher_model.eval()
test_correct, test_total = 0, 0

with torch.no_grad():
    for batch_images, batch_labels in test_loader:
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)
        predicted = teacher_model(batch_images).argmax(dim=1)
        test_correct += int((predicted == batch_labels).sum())
        test_total += batch_labels.size(0)

test_acc = test_correct / test_total
print(f"Test Accuracy: {test_acc*100:.2f}%")

In [None]:
teacher_model.eval()

all_preds = []
all_labels = []

with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)

        outputs = teacher_model(images)
        _, preds = torch.max(outputs, 1)

        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

print("\n===== CLASSIFICATION REPORT (TEST SET) =====\n")
print(
    classification_report(
        all_labels,
        all_preds,
        # Pin the full label set so target_names always lines up; without it sklearn
        # raises when a class appears in neither the labels nor the predictions.
        labels=list(range(len(test_dataset.classes))),
        target_names=test_dataset.classes,
        digits=4,
        zero_division=0,
    )
)

In [None]:
# Persist the label-id -> class-name order used by ImageFolder, so predictions can be decoded later.
with open("teacher_classes.json", mode="w") as fh:
    json.dump(obj=train_dataset.classes, fp=fh)

In [None]:
# (Re)mount Google Drive before writing the model artifacts to it.
from google.colab import drive
drive.mount(mountpoint='/content/drive')

In [None]:
import os, json, torch

# Drive folder holding the teacher weights and the class-name list.
SAVE_DIR = "/content/drive/MyDrive/teacher_model"
os.makedirs(SAVE_DIR, exist_ok=True)

weights_path = os.path.join(SAVE_DIR, "teacher_efficientnet_b2.pth")
classes_path = os.path.join(SAVE_DIR, "teacher_classes.json")

torch.save(teacher_model.state_dict(), weights_path)
with open(classes_path, "w") as fh:
    json.dump(train_dataset.classes, fh)

In [None]:
# Experiment metadata, derived from the objects actually used above instead of hand-typed
# values (the previous record said 260px input and AdamW, while the code trains at 224px with Adam).
experiment_info = {
    "dataset": "tomato-leaf-disease",
    "source": "Kaggle + GitHub (Tomato Village)",
    "num_classes": len(train_dataset.classes),
    "classes": train_dataset.classes,
    "input_size": 224,  # must match transforms.Resize((224, 224)) in train/val transforms
    "model": "efficientnet_b2",
    "epochs": len(history["val_acc"]),
    "optimizer": type(optimizer).__name__,
    "best_val_acc": max(history["val_acc"]),
}

import json
with open("/content/drive/MyDrive/teacher_model/experiment.json", "w") as f:
    json.dump(experiment_info, f, indent=4)

In [None]:
from google.colab import files

# Upload teacher_efficientnet_b2.pth and teacher_classes.json (both read by the next cell).
# Keep the result in a variable rather than as the cell's last expression, so the raw file
# bytes are not echoed into the notebook output.
uploaded_files = files.upload()
print("Uploaded:", list(uploaded_files))

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

with open("teacher_classes.json") as f:
    classes = json.load(f)

teacher_model = timm.create_model(
    'efficientnet_b2',
    pretrained=False,        # weights come from the checkpoint below; no download needed
    num_classes=len(classes)
)

# The checkpoint is a user-uploaded file: weights_only=True refuses to unpickle arbitrary
# objects (a plain state_dict only needs tensors). map_location allows loading a
# GPU-saved checkpoint on a CPU runtime.
teacher_model.load_state_dict(
    torch.load("teacher_efficientnet_b2.pth", map_location=device, weights_only=True)
)

teacher_model = teacher_model.to(device)
teacher_model.eval()

**Saving the Datasets**

In [None]:
# Delete any previous Drive copy of the merged dataset so the copy below starts clean.
!rm -rf /content/drive/MyDrive/teacher_dataset

In [None]:
# Recreate the (now empty) destination folder on Drive.
!mkdir -p /content/drive/MyDrive/teacher_dataset

In [None]:
# Copy every class folder of the merged dataset to Drive (the shell `*` glob skips dot-entries).
!cp -r /content/teacher_dataset/* /content/drive/MyDrive/teacher_dataset/

In [None]:
# Local split (source) and its mirror on the mounted Drive (destination).
SRC_SPLIT_DIR = os.path.join("/content", "teacher_dataset_split")
DRIVE_SPLIT_DIR = os.path.join("/content/drive/MyDrive", os.path.basename(SRC_SPLIT_DIR))

In [None]:
import shutil
import os

# Remove any earlier (possibly partial) copy of the split on Drive, then recreate the folder.
if os.path.exists(DRIVE_SPLIT_DIR):
    shutil.rmtree(DRIVE_SPLIT_DIR)

os.makedirs(DRIVE_SPLIT_DIR)
print("🧹 Old broken split removed")

In [None]:
# Copy train/val/test class folders to Drive (dirs_exist_ok: the target root was just created).
shutil.copytree(SRC_SPLIT_DIR, DRIVE_SPLIT_DIR, dirs_exist_ok=True)
print("✅ Teacher split dataset saved to Drive correctly")

In [None]:
# Copy the training history to the root of MyDrive (not the teacher_model/ folder used for the weights).
!cp teacher_history.pt /content/drive/MyDrive/

In [None]:
# Sanity check: list the class folders copied to Drive.
!ls /content/drive/MyDrive/teacher_dataset

In [None]:
# Same check, restricted to directories directly under the dataset root.
!find /content/drive/MyDrive/teacher_dataset -maxdepth 1 -type d

In [None]:
# Sanity check: the Drive split should contain train/, val/ and test/.
!ls /content/drive/MyDrive/teacher_dataset_split

In [None]:
# List the split folders and their class sub-folders (two levels deep).
!find /content/drive/MyDrive/teacher_dataset_split/ -maxdepth 2 -type d

**Reloading From Drive**

In [None]:
# Mount Drive again for a session that resumes from the saved artifacts.
from google.colab import drive
drive.mount(mountpoint='/content/drive')

In [None]:
TEACHER_DATASET_DIR = "/content/drive/MyDrive/teacher_dataset"
SPLIT_DIR = "/content/drive/MyDrive/teacher_dataset_split"

# Re-point the split directories and the test loader at the Drive copy. Previously only
# SPLIT_DIR changed, so the evaluation cells below kept using the loader built from the
# local /content split (which does not exist in a fresh session).
TRAIN_DIR = os.path.join(SPLIT_DIR, "train")
VAL_DIR = os.path.join(SPLIT_DIR, "val")
TEST_DIR = os.path.join(SPLIT_DIR, "test")

test_dataset = datasets.ImageFolder(TEST_DIR, transform=val_transform)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)

In [None]:
# The weights were saved under MyDrive/teacher_model/ by the "save model" cell; the previous
# path (MyDrive root) did not match. map_location lets a GPU-saved checkpoint load on CPU,
# and weights_only=True avoids unpickling arbitrary objects from the file.
teacher_model.load_state_dict(
    torch.load(
        "/content/drive/MyDrive/teacher_model/teacher_efficientnet_b2.pth",
        map_location=device,
        weights_only=True,
    )
)

In [None]:
# Move the reloaded model to the active device and switch to inference mode
# (Module.to() and Module.eval() both return the module itself).
teacher_model = teacher_model.to(device).eval()
teacher_model

In [None]:
from sklearn.metrics import confusion_matrix
teacher_model.eval()

# Gather test-set predictions and ground-truth labels (as numpy scalars) for the confusion matrix
all_preds, all_labels = [], []

with torch.no_grad():
    for batch_x, batch_y in test_loader:
        logits = teacher_model(batch_x.to(device))
        all_preds.extend(logits.argmax(dim=1).cpu().numpy())
        all_labels.extend(batch_y.cpu().numpy())

In [None]:
# Index i of the confusion matrix corresponds to ImageFolder label i of the test dataset.
class_names = test_dataset.classes
print(class_names)

In [None]:
# Pin the full label range so cm is always (n_classes, n_classes) and lines up with
# class_names, even when a class is absent from both the labels and the predictions.
cm = confusion_matrix(all_labels, all_preds, labels=list(range(len(class_names))))

In [None]:
import seaborn as sns

# Raw-count confusion matrix: rows = true class, columns = predicted class.
plt.figure(figsize=(10, 8))
sns.heatmap(
    cm,
    annot=True,
    fmt="d",
    cmap="Blues",
    xticklabels=class_names,
    yticklabels=class_names
)

plt.xlabel("Predicted Label")
plt.ylabel("True Label")
# En dash; it was previously stored as mis-decoded UTF-8 ("â€“").
plt.title("Confusion Matrix – Teacher Model (Test Set)")
plt.xticks(rotation=45, ha="right")
plt.yticks(rotation=0)
plt.tight_layout()
plt.show()

In [None]:
# Row-normalize so each row shows the fraction of a true class assigned to each prediction.
# Rows with no true samples would divide by zero (NaN cells); leave them as 0 instead.
row_totals = cm.sum(axis=1, keepdims=True)
cm_norm = np.divide(
    cm.astype(float),
    row_totals,
    out=np.zeros(cm.shape, dtype=float),
    where=row_totals != 0,
)

plt.figure(figsize=(14, 12))
sns.heatmap(
    cm_norm,
    annot=True,
    fmt=".2f",
    cmap="Greens",
    xticklabels=class_names,
    yticklabels=class_names
)

plt.xlabel("Predicted Label")
plt.ylabel("True Label")
plt.title("Normalized Confusion Matrix – Teacher Model (Test Set)")
plt.xticks(rotation=45, ha="right")
plt.yticks(rotation=0)
plt.tight_layout()
plt.show()