In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, TensorDataset, SubsetRandomSampler

from tqdm import tqdm
import matplotlib.pyplot as plt

from sklearn.model_selection import KFold
import itertools
import numpy as np
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt


DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

Using device: cuda


In [5]:
NUM_EPOCHS = 80
BATCH_SIZE = 32
LEARNING_RATE = 1e-4
BATCH_SIZE = 32
NUM_CLASSES = 10      # FashionMNIST has 10 classes
INPUT_CHANNELS = 1    # FashionMNIST is grayscale

In [6]:
image_transforms = transforms.Compose([
    transforms.Resize(224), # ResNet-18 standard input size
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.5], # Grayscale mean
        std=[0.5]   # Grayscale std
    ),
])

In [7]:
# Training
train_dataset = torchvision.datasets.FashionMNIST(
    root='./data',
    train=True,
    download=True,
    transform=image_transforms
)
# Testing
test_dataset = torchvision.datasets.FashionMNIST(
    root='./data',
    train=False,
    download=True,
    transform=image_transforms
)

# Data Loader
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=5
)
test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=5
)

100%|██████████| 26.4M/26.4M [00:36<00:00, 721kB/s] 
100%|██████████| 29.5k/29.5k [00:00<00:00, 178kB/s]
100%|██████████| 4.42M/4.42M [00:01<00:00, 3.09MB/s]
100%|██████████| 5.15k/5.15k [00:00<00:00, 7.73MB/s]


In [8]:
def train_model(model, train_loader, criterion, optimizer, num_epochs, device):
    """Handles the training process and evaluates after each epoch, tracking losses."""
    model.train()
    print("\nStarting training on FashionMNIST...")

    # Clear logs before starting
    train_losses, test_losses, test_accuracies=[],[],[]

    for epoch in range(num_epochs):
        total_loss = 0.0
        total_samples = 0

        # Training Loop
        for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # === sample-weighted loss, matches evaluation ===
            batch_size = labels.size(0)
            total_loss += loss.item() * batch_size
            total_samples += batch_size

        # === Correct per-sample training loss ===
        epoch_loss = total_loss / total_samples
        train_losses.append(epoch_loss)

        # Evaluate on test set
        acc, test_loss = evaluate_model(model, test_loader, criterion, device)

        test_losses.append(test_loss)
        test_accuracies.append(acc)

        print(f"Epoch {epoch+1} finished. Train Loss: {epoch_loss:.4f} | Test Loss: {test_loss:.4f} | Test Accuracy: {acc:.2f}%")
    return train_losses, test_losses, test_accuracies


def evaluate_model(model, test_loader, criterion, device):
    """Calculates the Top-1 accuracy AND the average test loss."""
    model.eval()
    correct_predictions = 0
    total_samples = 0
    total_loss = 0.0

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            batch_size = labels.size(0)
            total_loss += loss.item() * batch_size
            total_samples += batch_size

            _, predicted = outputs.max(1)
            correct_predictions += (predicted == labels).sum().item()

    accuracy = 100 * correct_predictions / total_samples
    avg_loss = total_loss / total_samples

    return accuracy, avg_loss


In [9]:
class SimpleBilinearCNN(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 64, 3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(inplace=True),
            nn.MaxPool2d(2),           # -> 14x14
            nn.Conv2d(128, 256, 3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, 3, padding=1), nn.ReLU(inplace=True),
            nn.MaxPool2d(2)            # -> 7x7
        )
        # Bilinear pooling output: 256x256 features
        self.fc = nn.Linear(256 * 256, num_classes)
        
    def forward(self, x):
        x = self.features(x)  # [B, 256, 7, 7]
        B, C, H, W = x.size()
        x = x.view(B, C, H * W)
        # Bilinear pooling
        x = torch.bmm(x, x.transpose(1, 2)) / (H * W)  # [B, C, C]
        x = x.view(B, -1)
        x = torch.sign(x) * torch.sqrt(torch.abs(x) + 1e-10)  # signed sqrt
        x = F.normalize(x)
        return self.fc(x)

model= SimpleBilinearCNN().to(DEVICE)

In [10]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(),lr=LEARNING_RATE,weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=NUM_EPOCHS, gamma=0.1)
import gc
gc.collect()
torch.cuda.empty_cache()
train_losses, test_losses, test_accuracies = train_model(
    model, train_loader, criterion, optimizer, NUM_EPOCHS, DEVICE
)



Starting training on FashionMNIST...


Epoch 1/80: 100%|██████████| 1875/1875 [02:35<00:00, 12.09it/s]


Epoch 1 finished. Train Loss: 1.3349 | Test Loss: 0.9360 | Test Accuracy: 68.26%


Epoch 2/80: 100%|██████████| 1875/1875 [03:45<00:00,  8.32it/s]


Epoch 2 finished. Train Loss: 0.7856 | Test Loss: 0.7112 | Test Accuracy: 76.17%


Epoch 3/80: 100%|██████████| 1875/1875 [03:43<00:00,  8.38it/s]


Epoch 3 finished. Train Loss: 0.6478 | Test Loss: 0.6055 | Test Accuracy: 79.32%


Epoch 4/80: 100%|██████████| 1875/1875 [03:46<00:00,  8.28it/s]


Epoch 4 finished. Train Loss: 0.5716 | Test Loss: 0.5681 | Test Accuracy: 80.88%


Epoch 5/80: 100%|██████████| 1875/1875 [03:45<00:00,  8.31it/s]


Epoch 5 finished. Train Loss: 0.5250 | Test Loss: 0.5159 | Test Accuracy: 82.29%


Epoch 6/80: 100%|██████████| 1875/1875 [03:45<00:00,  8.32it/s]


Epoch 6 finished. Train Loss: 0.4904 | Test Loss: 0.5021 | Test Accuracy: 83.13%


Epoch 7/80: 100%|██████████| 1875/1875 [03:45<00:00,  8.30it/s]


Epoch 7 finished. Train Loss: 0.4643 | Test Loss: 0.4684 | Test Accuracy: 83.93%


Epoch 8/80: 100%|██████████| 1875/1875 [03:44<00:00,  8.35it/s]


Epoch 8 finished. Train Loss: 0.4433 | Test Loss: 0.4503 | Test Accuracy: 84.94%


Epoch 9/80: 100%|██████████| 1875/1875 [03:46<00:00,  8.29it/s]


Epoch 9 finished. Train Loss: 0.4233 | Test Loss: 0.4378 | Test Accuracy: 84.86%


Epoch 10/80: 100%|██████████| 1875/1875 [03:44<00:00,  8.35it/s]


Epoch 10 finished. Train Loss: 0.4076 | Test Loss: 0.4310 | Test Accuracy: 85.13%


Epoch 11/80: 100%|██████████| 1875/1875 [03:46<00:00,  8.29it/s]


Epoch 11 finished. Train Loss: 0.3941 | Test Loss: 0.4165 | Test Accuracy: 85.46%


Epoch 12/80: 100%|██████████| 1875/1875 [03:44<00:00,  8.35it/s]


Epoch 12 finished. Train Loss: 0.3811 | Test Loss: 0.4108 | Test Accuracy: 85.68%


Epoch 13/80: 100%|██████████| 1875/1875 [03:46<00:00,  8.27it/s]


Epoch 13 finished. Train Loss: 0.3678 | Test Loss: 0.4066 | Test Accuracy: 85.55%


Epoch 14/80: 100%|██████████| 1875/1875 [03:44<00:00,  8.34it/s]


Epoch 14 finished. Train Loss: 0.3592 | Test Loss: 0.3932 | Test Accuracy: 86.11%


Epoch 15/80: 100%|██████████| 1875/1875 [03:46<00:00,  8.26it/s]


Epoch 15 finished. Train Loss: 0.3506 | Test Loss: 0.3904 | Test Accuracy: 86.12%


Epoch 16/80: 100%|██████████| 1875/1875 [03:47<00:00,  8.23it/s]


Epoch 16 finished. Train Loss: 0.3409 | Test Loss: 0.3882 | Test Accuracy: 86.26%


Epoch 17/80: 100%|██████████| 1875/1875 [03:45<00:00,  8.33it/s]


Epoch 17 finished. Train Loss: 0.3326 | Test Loss: 0.3726 | Test Accuracy: 87.00%


Epoch 18/80: 100%|██████████| 1875/1875 [03:46<00:00,  8.28it/s]


Epoch 18 finished. Train Loss: 0.3251 | Test Loss: 0.3749 | Test Accuracy: 86.74%


Epoch 19/80: 100%|██████████| 1875/1875 [03:44<00:00,  8.34it/s]


Epoch 19 finished. Train Loss: 0.3167 | Test Loss: 0.3712 | Test Accuracy: 87.15%


Epoch 20/80: 100%|██████████| 1875/1875 [03:46<00:00,  8.29it/s]


Epoch 20 finished. Train Loss: 0.3111 | Test Loss: 0.3713 | Test Accuracy: 86.98%


Epoch 21/80: 100%|██████████| 1875/1875 [03:44<00:00,  8.35it/s]


Epoch 21 finished. Train Loss: 0.3053 | Test Loss: 0.3896 | Test Accuracy: 85.95%


Epoch 22/80: 100%|██████████| 1875/1875 [03:46<00:00,  8.28it/s]


Epoch 22 finished. Train Loss: 0.2974 | Test Loss: 0.3564 | Test Accuracy: 87.49%


Epoch 23/80: 100%|██████████| 1875/1875 [03:45<00:00,  8.32it/s]


Epoch 23 finished. Train Loss: 0.2910 | Test Loss: 0.3549 | Test Accuracy: 87.54%


Epoch 24/80: 100%|██████████| 1875/1875 [03:46<00:00,  8.26it/s]


Epoch 24 finished. Train Loss: 0.2855 | Test Loss: 0.3433 | Test Accuracy: 87.97%


Epoch 25/80: 100%|██████████| 1875/1875 [03:44<00:00,  8.34it/s]


Epoch 25 finished. Train Loss: 0.2799 | Test Loss: 0.3518 | Test Accuracy: 87.61%


Epoch 26/80: 100%|██████████| 1875/1875 [03:46<00:00,  8.27it/s]


Epoch 26 finished. Train Loss: 0.2742 | Test Loss: 0.3465 | Test Accuracy: 87.61%


Epoch 27/80: 100%|██████████| 1875/1875 [03:46<00:00,  8.26it/s]


Epoch 27 finished. Train Loss: 0.2680 | Test Loss: 0.3650 | Test Accuracy: 86.97%


Epoch 28/80: 100%|██████████| 1875/1875 [03:45<00:00,  8.33it/s]


Epoch 28 finished. Train Loss: 0.2624 | Test Loss: 0.3452 | Test Accuracy: 87.65%


Epoch 29/80: 100%|██████████| 1875/1875 [03:46<00:00,  8.28it/s]


Epoch 29 finished. Train Loss: 0.2567 | Test Loss: 0.3432 | Test Accuracy: 87.92%


Epoch 30/80: 100%|██████████| 1875/1875 [03:44<00:00,  8.34it/s]


Epoch 30 finished. Train Loss: 0.2526 | Test Loss: 0.3406 | Test Accuracy: 87.99%


Epoch 31/80: 100%|██████████| 1875/1875 [03:46<00:00,  8.27it/s]


Epoch 31 finished. Train Loss: 0.2468 | Test Loss: 0.3427 | Test Accuracy: 87.83%


Epoch 32/80: 100%|██████████| 1875/1875 [03:44<00:00,  8.35it/s]


Epoch 32 finished. Train Loss: 0.2416 | Test Loss: 0.3379 | Test Accuracy: 88.04%


Epoch 33/80: 100%|██████████| 1875/1875 [03:45<00:00,  8.30it/s]


Epoch 33 finished. Train Loss: 0.2357 | Test Loss: 0.3362 | Test Accuracy: 87.87%


Epoch 34/80: 100%|██████████| 1875/1875 [03:44<00:00,  8.34it/s]


Epoch 34 finished. Train Loss: 0.2309 | Test Loss: 0.3420 | Test Accuracy: 87.66%


Epoch 35/80: 100%|██████████| 1875/1875 [03:45<00:00,  8.31it/s]


Epoch 35 finished. Train Loss: 0.2263 | Test Loss: 0.3480 | Test Accuracy: 87.60%


Epoch 36/80: 100%|██████████| 1875/1875 [03:44<00:00,  8.34it/s]


Epoch 36 finished. Train Loss: 0.2216 | Test Loss: 0.3401 | Test Accuracy: 87.90%


Epoch 37/80: 100%|██████████| 1875/1875 [03:46<00:00,  8.29it/s]


Epoch 37 finished. Train Loss: 0.2165 | Test Loss: 0.3310 | Test Accuracy: 88.08%


Epoch 38/80: 100%|██████████| 1875/1875 [03:46<00:00,  8.27it/s]


Epoch 38 finished. Train Loss: 0.2114 | Test Loss: 0.3300 | Test Accuracy: 88.19%


Epoch 39/80: 100%|██████████| 1875/1875 [03:45<00:00,  8.33it/s]


Epoch 39 finished. Train Loss: 0.2070 | Test Loss: 0.3582 | Test Accuracy: 87.05%


Epoch 40/80: 100%|██████████| 1875/1875 [03:46<00:00,  8.27it/s]


Epoch 40 finished. Train Loss: 0.2015 | Test Loss: 0.3433 | Test Accuracy: 87.70%


Epoch 41/80: 100%|██████████| 1875/1875 [03:45<00:00,  8.32it/s]


Epoch 41 finished. Train Loss: 0.1960 | Test Loss: 0.3385 | Test Accuracy: 88.03%


Epoch 42/80: 100%|██████████| 1875/1875 [03:46<00:00,  8.29it/s]


Epoch 42 finished. Train Loss: 0.1925 | Test Loss: 0.3389 | Test Accuracy: 88.18%


Epoch 43/80: 100%|██████████| 1875/1875 [03:45<00:00,  8.32it/s]


Epoch 43 finished. Train Loss: 0.1861 | Test Loss: 0.3362 | Test Accuracy: 88.10%


Epoch 44/80: 100%|██████████| 1875/1875 [03:46<00:00,  8.28it/s]


Epoch 44 finished. Train Loss: 0.1829 | Test Loss: 0.3463 | Test Accuracy: 87.60%


Epoch 45/80: 100%|██████████| 1875/1875 [03:45<00:00,  8.32it/s]


Epoch 45 finished. Train Loss: 0.1788 | Test Loss: 0.3419 | Test Accuracy: 87.83%


Epoch 46/80: 100%|██████████| 1875/1875 [03:46<00:00,  8.27it/s]


Epoch 46 finished. Train Loss: 0.1723 | Test Loss: 0.3356 | Test Accuracy: 88.00%


Epoch 47/80: 100%|██████████| 1875/1875 [03:46<00:00,  8.27it/s]


Epoch 47 finished. Train Loss: 0.1680 | Test Loss: 0.3428 | Test Accuracy: 88.15%


Epoch 48/80: 100%|██████████| 1875/1875 [03:46<00:00,  8.26it/s]


Epoch 48 finished. Train Loss: 0.1642 | Test Loss: 0.3378 | Test Accuracy: 88.09%


Epoch 49/80: 100%|██████████| 1875/1875 [03:46<00:00,  8.29it/s]


Epoch 49 finished. Train Loss: 0.1598 | Test Loss: 0.3397 | Test Accuracy: 87.94%


Epoch 50/80: 100%|██████████| 1875/1875 [03:45<00:00,  8.33it/s]


Epoch 50 finished. Train Loss: 0.1543 | Test Loss: 0.3446 | Test Accuracy: 87.62%


Epoch 51/80: 100%|██████████| 1875/1875 [03:46<00:00,  8.27it/s]


Epoch 51 finished. Train Loss: 0.1497 | Test Loss: 0.3460 | Test Accuracy: 88.23%


Epoch 52/80: 100%|██████████| 1875/1875 [03:44<00:00,  8.35it/s]


Epoch 52 finished. Train Loss: 0.1463 | Test Loss: 0.3443 | Test Accuracy: 87.97%


Epoch 53/80: 100%|██████████| 1875/1875 [03:46<00:00,  8.29it/s]


Epoch 53 finished. Train Loss: 0.1410 | Test Loss: 0.3625 | Test Accuracy: 87.73%


Epoch 54/80: 100%|██████████| 1875/1875 [03:44<00:00,  8.36it/s]


Epoch 54 finished. Train Loss: 0.1372 | Test Loss: 0.3434 | Test Accuracy: 87.97%


Epoch 55/80: 100%|██████████| 1875/1875 [03:45<00:00,  8.30it/s]


Epoch 55 finished. Train Loss: 0.1318 | Test Loss: 0.3483 | Test Accuracy: 87.81%


Epoch 56/80: 100%|██████████| 1875/1875 [03:44<00:00,  8.36it/s]


Epoch 56 finished. Train Loss: 0.1281 | Test Loss: 0.3482 | Test Accuracy: 87.78%


Epoch 57/80: 100%|██████████| 1875/1875 [03:46<00:00,  8.29it/s]


Epoch 57 finished. Train Loss: 0.1252 | Test Loss: 0.3543 | Test Accuracy: 87.45%


Epoch 58/80: 100%|██████████| 1875/1875 [03:46<00:00,  8.29it/s]


Epoch 58 finished. Train Loss: 0.1200 | Test Loss: 0.3519 | Test Accuracy: 87.59%


Epoch 59/80: 100%|██████████| 1875/1875 [03:46<00:00,  8.29it/s]


Epoch 59 finished. Train Loss: 0.1151 | Test Loss: 0.3470 | Test Accuracy: 87.81%


Epoch 60/80: 100%|██████████| 1875/1875 [03:46<00:00,  8.27it/s]


Epoch 60 finished. Train Loss: 0.1118 | Test Loss: 0.3602 | Test Accuracy: 87.61%


Epoch 61/80: 100%|██████████| 1875/1875 [03:45<00:00,  8.32it/s]


Epoch 61 finished. Train Loss: 0.1078 | Test Loss: 0.3505 | Test Accuracy: 87.74%


Epoch 62/80: 100%|██████████| 1875/1875 [03:47<00:00,  8.26it/s]


Epoch 62 finished. Train Loss: 0.1047 | Test Loss: 0.3717 | Test Accuracy: 87.32%


Epoch 63/80: 100%|██████████| 1875/1875 [03:45<00:00,  8.33it/s]


Epoch 63 finished. Train Loss: 0.1008 | Test Loss: 0.3736 | Test Accuracy: 87.41%


Epoch 64/80: 100%|██████████| 1875/1875 [03:46<00:00,  8.26it/s]


Epoch 64 finished. Train Loss: 0.0964 | Test Loss: 0.3579 | Test Accuracy: 87.70%


Epoch 65/80: 100%|██████████| 1875/1875 [03:45<00:00,  8.33it/s]


Epoch 65 finished. Train Loss: 0.0924 | Test Loss: 0.3877 | Test Accuracy: 86.93%


Epoch 66/80: 100%|██████████| 1875/1875 [03:46<00:00,  8.29it/s]


Epoch 66 finished. Train Loss: 0.0889 | Test Loss: 0.3515 | Test Accuracy: 88.04%


Epoch 67/80: 100%|██████████| 1875/1875 [03:45<00:00,  8.33it/s]


Epoch 67 finished. Train Loss: 0.0866 | Test Loss: 0.3609 | Test Accuracy: 87.92%


Epoch 68/80: 100%|██████████| 1875/1875 [03:46<00:00,  8.26it/s]


Epoch 68 finished. Train Loss: 0.0824 | Test Loss: 0.3712 | Test Accuracy: 87.24%


Epoch 69/80: 100%|██████████| 1875/1875 [03:46<00:00,  8.28it/s]


Epoch 69 finished. Train Loss: 0.0791 | Test Loss: 0.3689 | Test Accuracy: 87.63%


Epoch 70/80: 100%|██████████| 1875/1875 [03:45<00:00,  8.30it/s]


Epoch 70 finished. Train Loss: 0.0752 | Test Loss: 0.3658 | Test Accuracy: 87.49%


Epoch 71/80: 100%|██████████| 1875/1875 [03:45<00:00,  8.30it/s]


Epoch 71 finished. Train Loss: 0.0732 | Test Loss: 0.3752 | Test Accuracy: 87.24%


Epoch 72/80: 100%|██████████| 1875/1875 [03:45<00:00,  8.33it/s]


Epoch 72 finished. Train Loss: 0.0696 | Test Loss: 0.3809 | Test Accuracy: 87.70%


Epoch 73/80: 100%|██████████| 1875/1875 [03:45<00:00,  8.30it/s]


Epoch 73 finished. Train Loss: 0.0663 | Test Loss: 0.3914 | Test Accuracy: 86.93%


Epoch 74/80: 100%|██████████| 1875/1875 [03:44<00:00,  8.35it/s]


Epoch 74 finished. Train Loss: 0.0636 | Test Loss: 0.3807 | Test Accuracy: 87.55%


Epoch 75/80: 100%|██████████| 1875/1875 [03:45<00:00,  8.31it/s]


Epoch 75 finished. Train Loss: 0.0599 | Test Loss: 0.3814 | Test Accuracy: 87.53%


Epoch 76/80: 100%|██████████| 1875/1875 [03:44<00:00,  8.36it/s]


Epoch 76 finished. Train Loss: 0.0581 | Test Loss: 0.3862 | Test Accuracy: 87.59%


Epoch 77/80: 100%|██████████| 1875/1875 [03:46<00:00,  8.29it/s]


Epoch 77 finished. Train Loss: 0.0560 | Test Loss: 0.3884 | Test Accuracy: 87.67%


Epoch 78/80: 100%|██████████| 1875/1875 [03:44<00:00,  8.35it/s]


Epoch 78 finished. Train Loss: 0.0537 | Test Loss: 0.3923 | Test Accuracy: 87.35%


Epoch 79/80: 100%|██████████| 1875/1875 [03:45<00:00,  8.30it/s]


Epoch 79 finished. Train Loss: 0.0504 | Test Loss: 0.3908 | Test Accuracy: 87.57%


Epoch 80/80: 100%|██████████| 1875/1875 [03:45<00:00,  8.31it/s]


Epoch 80 finished. Train Loss: 0.0485 | Test Loss: 0.3980 | Test Accuracy: 87.28%


In [11]:
import pandas as pd
df = pd.DataFrame({
        "epoch": list(range(1, len(train_losses) + 1)),
        "train_loss": train_losses,
        "test_loss": test_losses,
        "test_accuracy": test_accuracies,
    })
df.to_csv("BCNN_mnist.csv")