In [3]:
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import timm

# --------------------- Dataset ---------------------

class ImageArrayDataset(Dataset):
    """Dataset of ``(image file, .npy label file)`` pairs.

    ``__getitem__`` decodes the image as RGB, applies ``transform`` when one
    is given, and converts the 1-D ``.npy`` label vector to a class index
    with ``argmax`` (so one-hot or per-class score vectors both work).

    Args:
        samples: Sequence of ``(img_path, npy_path)`` tuples.
        transform: Optional callable applied to the PIL image.
    """

    def __init__(self, samples, transform=None):
        self.samples = samples
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    @staticmethod
    def _decode_label(label_array, source="<array>"):
        """Return the class index encoded by a 1-D label vector.

        Args:
            label_array: Array-like label vector.
            source: Description of where the vector came from, used in the
                error message.

        Returns:
            Index of the largest entry as a Python ``int``. An all-zero
            vector maps to class 0 -- NOTE(review): confirm that cannot
            occur in the data.

        Raises:
            ValueError: if the array is not 1-D or is empty. (A real
                exception rather than ``assert``, which ``python -O`` strips.)
        """
        label_array = np.asarray(label_array)
        if label_array.ndim != 1 or label_array.size == 0:
            raise ValueError(f"Invalid label shape in {source}")
        return int(np.argmax(label_array))

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        Raises:
            RuntimeError: if the image cannot be opened or decoded
                (the original exception is chained as ``__cause__``).
            ValueError: if the label file does not hold a non-empty 1-D array.
        """
        img_path, npy_path = self.samples[idx]
        try:
            # The context manager closes the file handle promptly;
            # convert() returns a new, fully loaded image that outlives it.
            with Image.open(img_path) as img:
                image = img.convert('RGB')
        except Exception as e:
            raise RuntimeError(f"Failed to open image {img_path}: {e}") from e
        if self.transform is not None:
            image = self.transform(image)
        label = self._decode_label(np.load(npy_path), npy_path)
        return image, label

# --------------------- Data Loading ---------------------

def load_samples(folder_path, extensions=('.jpg', '.png')):
    """Collect ``(image_path, npy_path)`` pairs from one folder.

    An image is kept only when a ``.npy`` file with the same base name sits
    next to it; images without a label file are skipped silently. The scan
    is not recursive.

    Args:
        folder_path: Directory to scan.
        extensions: Image suffixes to accept (matched case-insensitively).

    Returns:
        List of ``(img_path, npy_path)`` tuples ordered by file name.
        Sorting matters: ``os.listdir`` order is filesystem-dependent, so
        without it a seeded shuffle downstream would still produce
        different train/val splits on different machines.
    """
    suffixes = tuple(ext.lower() for ext in extensions)
    samples = []
    for fname in sorted(os.listdir(folder_path)):
        if not fname.lower().endswith(suffixes):
            continue
        img_path = os.path.join(folder_path, fname)
        npy_path = os.path.join(folder_path, os.path.splitext(fname)[0] + '.npy')
        if os.path.isfile(npy_path):
            samples.append((img_path, npy_path))
    return samples

def prepare_combined_loaders(folders, val_split=0.2, batch_size=16, seed=42,
                             train_transform=None):
    """Pool samples from several folders and split them into train/val loaders.

    Args:
        folders: Iterable of directories, each scanned with ``load_samples``.
        val_split: Fraction of the pooled samples held out for validation;
            must lie strictly between 0 and 1.
        batch_size: Batch size for both loaders.
        seed: Seeds the split shuffle and torch's global RNG (which the
            training DataLoader uses for its per-epoch shuffling).
        train_transform: Optional transform for the training set only, e.g.
            to add augmentation. Defaults to the same deterministic
            resize/normalize pipeline used for validation.

    Returns:
        ``(train_loader, val_loader)``; only the training loader shuffles.

    Raises:
        ValueError: if ``val_split`` is out of range, no samples are found,
            or the split leaves the train or validation set empty (the
            training loop would otherwise divide by zero).
    """
    if not 0.0 < val_split < 1.0:
        raise ValueError(f"val_split must be in (0, 1), got {val_split!r}")

    folders = list(folders)  # may be a one-shot iterable; reused in messages
    torch.manual_seed(seed)

    all_samples = []
    for folder in folders:
        all_samples.extend(load_samples(folder))
    if not all_samples:
        raise ValueError(f"No image/.npy pairs found in {folders}")

    # A private RNG yields the same permutation as random.seed(seed) +
    # random.shuffle, without resetting the global `random` state.
    random.Random(seed).shuffle(all_samples)
    split = int(len(all_samples) * (1 - val_split))
    train_samples = all_samples[:split]
    val_samples = all_samples[split:]
    if not train_samples or not val_samples:
        raise ValueError(
            f"val_split={val_split} leaves an empty split "
            f"({len(train_samples)} train / {len(val_samples)} val samples)"
        )

    eval_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    if train_transform is None:
        # NOTE(review): no augmentation by default. The logged run reaches
        # train acc 1.0 while val acc stays below 0.5, so supplying an
        # augmenting train_transform is worth trying.
        train_transform = eval_transform

    train_dataset = ImageArrayDataset(train_samples, transform=train_transform)
    val_dataset = ImageArrayDataset(val_samples, transform=eval_transform)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)
    return train_loader, val_loader

# --------------------- Training ---------------------

def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader`` and return ``(mean_loss, accuracy)``.

    Trains (zero_grad / backward / step) when ``optimizer`` is given;
    otherwise evaluates in eval mode with gradients disabled.

    Raises:
        ValueError: if ``loader`` yields no samples (avoids a division by zero).
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()
    total_loss, correct, total = 0.0, 0, 0
    with torch.set_grad_enabled(training):
        for imgs, labels in loader:
            imgs, labels = imgs.to(device), labels.to(device)
            if training:
                optimizer.zero_grad()
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()
            # criterion returns a batch mean; re-weight by batch size so the
            # epoch mean is exact even when the last batch is smaller.
            total_loss += loss.item() * imgs.size(0)
            correct += (outputs.argmax(1) == labels).sum().item()
            total += labels.size(0)
    if total == 0:
        raise ValueError("DataLoader yielded no samples")
    return total_loss / total, correct / total


def train(model, train_loader, val_loader, criterion, optimizer, scheduler, device, epochs, save_path, threshold):
    """Train for ``epochs`` epochs, checkpointing the best validation model.

    After each epoch the scheduler is stepped with the validation accuracy,
    so it must accept a metric in ``step`` (e.g. ``ReduceLROnPlateau`` with
    ``mode='max'``). The model's ``state_dict`` is written to ``save_path``
    whenever validation accuracy is at least ``threshold`` and at least the
    best accuracy saved so far (ties overwrite the earlier checkpoint).

    Args:
        model: Module to train in place.
        train_loader: Loader of ``(inputs, class_index_labels)`` batches.
        val_loader: Loader used for validation after every epoch.
        criterion: Loss function, e.g. ``nn.CrossEntropyLoss``.
        optimizer: Optimizer over ``model``'s parameters.
        scheduler: Scheduler stepped with the validation accuracy.
        device: Device the batches are moved to.
        epochs: Number of epochs to run.
        save_path: Destination for the best ``state_dict``.
        threshold: Minimum validation accuracy required to save.

    Returns:
        The best checkpointed validation accuracy, or ``None`` if no epoch
        reached ``threshold`` (nothing is written in that case).

    Raises:
        ValueError: if either loader yields no samples.
    """
    best_val_acc = 0.0
    saved = False
    for epoch in range(epochs):
        train_loss, train_acc = _run_epoch(model, train_loader, criterion, device, optimizer)
        val_loss, val_acc = _run_epoch(model, val_loader, criterion, device)
        scheduler.step(val_acc)

        if val_acc >= threshold and val_acc >= best_val_acc:
            best_val_acc = val_acc
            saved = True
            torch.save(model.state_dict(), save_path)
            print(f"Saved best model at epoch {epoch+1} (Val Acc = {val_acc:.4f})")

        print(f"[Epoch {epoch+1}] Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}, "
              f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

    if not saved:
        print(f"No epoch reached Val Acc >= {threshold}; nothing saved to {save_path}")
        return None
    return best_val_acc

# --------------------- Main ---------------------

def main():
    """Fine-tune a pretrained ViT-Tiny (5 classes) on the pooled v5/v6 folders.

    Uses an 80/20 train/validation split, AdamW at lr 1e-4, and halves the
    learning rate when validation accuracy plateaus for 2 epochs. The best
    checkpoint at or above 0.5 validation accuracy is written to
    ``vit_tiny_combined_best.pth`` in the current working directory.
    """
    data_dirs = [
        r"C:\Users\huang\Downloads\Engineering Projects\Genesys Lab\v5",
        r"C:\Users\huang\Downloads\Engineering Projects\Genesys Lab\v6",
    ]
    cfg = {
        "batch_size": 16,
        "epochs": 20,
        "val_threshold": 0.5,
        "seed": 42,
    }

    # Loaders are built first: they seed torch's RNG before the model is made.
    loaders = prepare_combined_loaders(
        data_dirs,
        val_split=0.2,
        batch_size=cfg["batch_size"],
        seed=cfg["seed"],
    )
    train_loader, val_loader = loaders

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    net = timm.create_model(
        'vit_tiny_patch16_224', pretrained=True, num_classes=5
    ).to(device)

    loss_fn = nn.CrossEntropyLoss()
    opt = optim.AdamW(net.parameters(), lr=1e-4)
    plateau = torch.optim.lr_scheduler.ReduceLROnPlateau(
        opt, mode='max', patience=2, factor=0.5
    )

    train(
        net, train_loader, val_loader, loss_fn, opt, plateau, device,
        cfg["epochs"],
        save_path='vit_tiny_combined_best.pth',
        threshold=cfg["val_threshold"],
    )

# Run the training pipeline when executed as a script (not when imported).
# Inside a notebook cell __name__ is also "__main__", so it runs there too.
if __name__ == "__main__":
    main()



[Epoch 1] Train Acc: 0.4390, Val Acc: 0.4359, Train Loss: 1.3776, Val Loss: 1.3381
[Epoch 2] Train Acc: 0.5696, Val Acc: 0.4786, Train Loss: 1.0637, Val Loss: 1.3038
[Epoch 3] Train Acc: 0.6638, Val Acc: 0.4701, Train Loss: 0.8487, Val Loss: 1.4845
[Epoch 4] Train Acc: 0.7152, Val Acc: 0.4274, Train Loss: 0.7320, Val Loss: 1.4464
[Epoch 5] Train Acc: 0.8351, Val Acc: 0.4701, Train Loss: 0.4298, Val Loss: 1.5464
Saved best model at epoch 6 (Val Acc = 0.5043)
[Epoch 6] Train Acc: 0.9593, Val Acc: 0.5043, Train Loss: 0.1579, Val Loss: 1.7744
[Epoch 7] Train Acc: 0.9850, Val Acc: 0.4957, Train Loss: 0.0499, Val Loss: 1.9086
[Epoch 8] Train Acc: 0.9850, Val Acc: 0.4872, Train Loss: 0.0565, Val Loss: 1.9697
[Epoch 9] Train Acc: 0.9936, Val Acc: 0.4274, Train Loss: 0.0269, Val Loss: 2.0762
[Epoch 10] Train Acc: 0.9979, Val Acc: 0.4188, Train Loss: 0.0120, Val Loss: 2.0949
[Epoch 11] Train Acc: 1.0000, Val Acc: 0.4359, Train Loss: 0.0061, Val Loss: 2.1304
[Epoch 12] Train Acc: 1.0000, Val Acc: