In [65]:
import nbimporter
from model import AlexNet
import torch
import os
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.nn as nn

In [66]:
# --- Run configuration -------------------------------------------------------
# Everything a reader might want to tune lives in this cell.

# Output location for per-epoch checkpoints and the final weights.
CHECKPOINT_DIR = 'checkpoints/'

# GPU device indices. NOTE(review): not referenced by any cell visible in this
# notebook -- confirm whether it is still needed.
GPUS = [0]

# Training schedule.
EPOCHS = 10
BATCH_SIZE = 128

# Width of the classifier output passed to AlexNet.
# NOTE(review): FashionMNIST (used below) has 10 classes; 100 outputs still
# train, but 90 logits never appear as targets. Changing this alters the model
# shape and invalidates previously saved checkpoints -- confirm before editing.
NUM_CLASSES = 100

# SGD hyperparameters consumed when the optimizer is built.
LEARNING_RATE = 1e-2
MOMENTUM = 0.9
WEIGHT_DECAY = 5e-4

In [67]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [68]:
# Remember the seed of torch's default generator so it can be stored with the
# checkpoints. NOTE(review): initial_seed() only reads the seed; nothing in this
# notebook sets one, so runs are not reproducible -- confirm whether a fixed
# torch.manual_seed(...) is wanted.
seed = torch.initial_seed()

# Build the network, then place its parameters on the selected device.
# The first positional argument (3) presumably is the number of input channels
# -- verify against model.AlexNet.
model = AlexNet(3, num_classes=NUM_CLASSES)
model = model.to(device)

In [69]:
# Preprocessing pipeline applied to every image, in this order:
#   1. replicate the single grayscale channel to 3 channels,
#   2. resize to 227x227,
#   3. convert the PIL image to a float tensor,
#   4. normalize with mean 0.5 / std 0.5 (the 1-tuples broadcast over channels).
data_transform = transforms.Compose(
    [
        transforms.Grayscale(num_output_channels=3),
        transforms.Resize(size=(227, 227)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

In [70]:
# Dataset root. Relative by default so the notebook runs on any machine instead
# of depending on one user's absolute Windows path; set the FASHION_MNIST_ROOT
# environment variable to reuse an existing download location.
DATA_ROOT = os.environ.get('FASHION_MNIST_ROOT', 'data')

# download=True fetches FashionMNIST into DATA_ROOT if it is not already there.
# train=True selects the training split; train=False selects the held-out split,
# which this notebook uses as its validation set.
train_dataset = datasets.FashionMNIST(root=DATA_ROOT, train=True, transform=data_transform, download=True)
val_dataset = datasets.FashionMNIST(root=DATA_ROOT, train=False, transform=data_transform, download=True)

In [71]:
# Training batches are reshuffled every epoch.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=BATCH_SIZE)

# Validation batches keep a fixed order: shuffling adds nothing to aggregated
# metrics and makes per-batch evaluation output non-deterministic.
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=BATCH_SIZE)

# SGD with momentum and L2 weight decay over all model parameters.
optim = torch.optim.SGD(
    model.parameters(),
    lr=LEARNING_RATE,
    momentum=MOMENTUM,
    weight_decay=WEIGHT_DECAY,
)

In [72]:
# Classification loss computed on the raw model outputs and integer labels.
criterion = nn.CrossEntropyLoss()

# Multiply the optimizer's learning rate by 0.1 after every 50 scheduler steps.
learningRate_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer=optim,
    step_size=50,
    gamma=0.1,
)

In [75]:
from tqdm import tqdm

# Create the checkpoint directory once, before training starts.
# exist_ok=True makes the cell safe to re-run.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

for epoch in range(EPOCHS):
    model.train()
    epoch_loss = 0.0   # sum of per-batch mean losses for this epoch
    correct = 0        # correctly classified training samples so far
    total = 0          # training samples seen so far

    train_loader_progress = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}", ncols=100)

    for X, y in train_loader_progress:
        X, y = X.to(device), y.to(device)

        # Standard optimisation step: clear old gradients, forward, backward, update.
        optim.zero_grad()
        pred = model(X)
        # The loss is already on the same device as `pred`; no extra transfer needed.
        loss = criterion(pred, y)
        loss.backward()
        optim.step()

        batch_loss = loss.item()
        epoch_loss += batch_loss

        # Predicted class = index of the largest output; detach() keeps this
        # bookkeeping out of the autograd graph (replaces the legacy `.data`).
        predicted = pred.detach().argmax(dim=1)
        total += y.size(0)
        correct += (predicted == y).sum().item()

        train_loader_progress.set_postfix({
            'loss': f'{batch_loss:.4f}',
            'accuracy': f'{correct/total:.4f}'
        })

    # Advance the learning-rate schedule once per epoch, after the optimizer
    # updates. Without this call the StepLR scheduler never changes the LR.
    learningRate_scheduler.step()

    # Mean of the per-batch losses, and running accuracy over the whole epoch.
    avg_loss = epoch_loss / len(train_loader)
    avg_accuracy = correct / total

    print(f"Epoch [{epoch+1}/{EPOCHS}] - loss: {avg_loss:.4f} - accuracy: {avg_accuracy:.4f}")

    # One checkpoint per epoch. The file name is 1-based while the stored
    # 'epoch' value is the 0-based loop index.
    checkpoint_path = os.path.join(CHECKPOINT_DIR, f'model_checkpoint{epoch+1}.pkl')
    state = {
        'epoch': epoch,
        'optimizer': optim.state_dict(),
        'model': model.state_dict(),
        # Scheduler state is stored too so a resumed run continues the LR schedule.
        'scheduler': learningRate_scheduler.state_dict(),
        'seed': seed
    }
    torch.save(state, checkpoint_path)

Epoch 1/10: 100%|█████████████████| 469/469 [1:55:36<00:00, 14.79s/it, loss=0.4854, accuracy=0.5143]


Epoch [1/10] - loss: 1.6415 - accuracy: 0.5143


Epoch 2/10: 100%|█████████████████| 469/469 [1:37:52<00:00, 12.52s/it, loss=0.4732, accuracy=0.8288]


Epoch [2/10] - loss: 0.4598 - accuracy: 0.8288


Epoch 3/10: 100%|█████████████████| 469/469 [1:15:06<00:00,  9.61s/it, loss=0.3039, accuracy=0.8679]


Epoch [3/10] - loss: 0.3577 - accuracy: 0.8679


Epoch 4/10: 100%|█████████████████| 469/469 [1:16:44<00:00,  9.82s/it, loss=0.2910, accuracy=0.8901]


Epoch [4/10] - loss: 0.3022 - accuracy: 0.8901


Epoch 5/10: 100%|█████████████████| 469/469 [2:18:05<00:00, 17.67s/it, loss=0.1520, accuracy=0.9003]


Epoch [5/10] - loss: 0.2718 - accuracy: 0.9003


Epoch 6/10: 100%|█████████████████| 469/469 [1:16:43<00:00,  9.82s/it, loss=0.2053, accuracy=0.9080]


Epoch [6/10] - loss: 0.2493 - accuracy: 0.9080


Epoch 7/10: 100%|█████████████████| 469/469 [1:13:52<00:00,  9.45s/it, loss=0.1772, accuracy=0.9159]


Epoch [7/10] - loss: 0.2265 - accuracy: 0.9159


Epoch 8/10: 100%|█████████████████| 469/469 [1:14:43<00:00,  9.56s/it, loss=0.1543, accuracy=0.9223]


Epoch [8/10] - loss: 0.2117 - accuracy: 0.9223


Epoch 9/10: 100%|█████████████████| 469/469 [1:14:31<00:00,  9.53s/it, loss=0.1384, accuracy=0.9265]


Epoch [9/10] - loss: 0.1987 - accuracy: 0.9265


Epoch 10/10: 100%|████████████████| 469/469 [1:14:09<00:00,  9.49s/it, loss=0.1681, accuracy=0.9309]


Epoch [10/10] - loss: 0.1868 - accuracy: 0.9309


In [76]:
# Persist only the trained weights (the state dict), not the whole module object.
final_model_path = os.path.join(CHECKPOINT_DIR, 'final_model1.pt')
torch.save(obj=model.state_dict(), f=final_model_path)
print(f"Model saved to {final_model_path}")

Model saved to checkpoints/final_model1.pt
