In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from torch.cuda.amp import GradScaler, autocast  # For mixed precision

from model import CIFAR10Net
from albumentations_transform import AlbumentationsTransform

class CIFAR10Dataset(Dataset):
    """Map-style dataset that pairs samples with labels and an optional transform.

    Args:
        data: Indexable collection of samples; its length is the dataset length.
        targets: Indexable collection of labels, indexed in parallel with ``data``.
        transform: Optional callable invoked as ``transform(sample, train)``.
            It is skipped when falsy (e.g. ``None``).
        train: Flag forwarded unchanged to ``transform`` on every lookup.
    """

    def __init__(self, data, targets, transform=None, train=True):
        self.data, self.targets = data, targets
        self.transform, self.train = transform, train

    def __len__(self):
        """Return the number of samples held in ``data``."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(sample, label)`` for ``idx``, applying the transform if one is set."""
        sample, label = self.data[idx], self.targets[idx]

        # No transform configured: hand back the stored sample untouched.
        if not self.transform:
            return sample, label

        return self.transform(sample, self.train), label

def train_one_epoch(model, device, train_loader, optimizer, criterion, epoch, scaler):
    """Run one mixed-precision training pass over ``train_loader``.

    Args:
        model: Network to optimise; switched to training mode on entry.
        device: ``torch.device`` or device string that each batch is moved to.
        train_loader: Iterable yielding ``(data, target)`` batches.
        optimizer: Optimizer whose gradients are zeroed per batch and which is
            stepped via ``scaler.step``.
        criterion: Loss callable invoked as ``criterion(output, target)``.
        epoch: Epoch number; used only in the progress-bar label.
        scaler: Gradient scaler exposing ``scale``, ``step`` and ``update``.

    Returns:
        tuple[float, float]: Mean of the per-batch losses and the training
        accuracy in percent over the samples processed. ``(0.0, 0.0)`` when
        the loader yields no batches.
    """
    model.train()
    pbar = tqdm(train_loader)

    # CUDA autocast only affects CUDA ops, so enable it only for CUDA devices.
    # torch.autocast replaces the deprecated torch.cuda.amp.autocast() and,
    # with enabled=False, stays silent on CPU-only hosts.
    use_amp = torch.device(device).type == "cuda"

    loss_sum = 0.0
    num_batches = 0
    correct = 0
    processed = 0

    for data, target in pbar:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()

        # Mixed precision: only the forward pass and loss run under autocast.
        with torch.autocast(device_type="cuda", enabled=use_amp):
            output = model(data)
            loss = criterion(output, target)

        # Backward on the scaled loss, then step and refresh the scale factor.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Read the loss once; every .item() call forces a host sync on GPU.
        batch_loss = loss.item()
        loss_sum += batch_loss
        num_batches += 1

        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()
        processed += len(data)

        pbar.set_description(f'Epoch: {epoch} Loss: {batch_loss:.4f} Acc: {100*correct/processed:.2f}%')

    if num_batches == 0 or processed == 0:
        return 0.0, 0.0
    return loss_sum / num_batches, 100 * correct / processed

def test(model, device, test_loader, criterion):
    model.eval()
    test_loss = 0
    correct = 0

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader)
    accuracy = 100. * correct / len(test_loader.dataset)

    print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {accuracy:.2f}%\n')
    return accuracy

def main(seed=42, batch_size=64, epochs=300, target_acc=85.0,
         data_root='./data', num_workers=2, checkpoint_path='best_model.pth'):
    """Train ``CIFAR10Net`` on CIFAR-10 with mixed precision and early stopping.

    Downloads CIFAR-10 if needed, trains with SGD and a StepLR schedule,
    evaluates after every epoch, saves the weights whenever accuracy improves,
    and stops once ``target_acc`` is reached.

    Args:
        seed: Seed applied to Python ``random``, NumPy and torch.
        batch_size: Batch size for both data loaders.
        epochs: Maximum number of epochs.
        target_acc: Accuracy (percent) at which training stops early;
            ``None`` disables early stopping.
        data_root: Directory passed to ``datasets.CIFAR10`` as ``root``.
        num_workers: Worker process count for both data loaders.
        checkpoint_path: File the best ``state_dict`` is written to.

    Returns:
        float: Best evaluation accuracy (percent) observed during the run.
    """
    import random

    import numpy as np

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Seed every RNG the pipeline might draw from, not just torch.
    # NOTE(review): which RNG the augmentation transform uses is not visible
    # here - confirm it is covered by one of these.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # One transform object is shared; each dataset passes its own train flag.
    transform = AlbumentationsTransform()

    # Load CIFAR10
    trainset = datasets.CIFAR10(root=data_root, train=True, download=True)
    testset = datasets.CIFAR10(root=data_root, train=False, download=True)

    # Create datasets
    train_dataset = CIFAR10Dataset(trainset.data, trainset.targets, transform, train=True)
    test_dataset = CIFAR10Dataset(testset.data, testset.targets, transform, train=False)

    # Create dataloaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    # Initialize model, criterion, optimizer
    model = CIFAR10Net().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.15, momentum=0.95)
    # Multiply the learning rate by 0.05 every 30 epochs.
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.05)

    # Mixed precision: scale gradients only on CUDA. On CPU the scaler is a
    # pass-through, and disabling it explicitly avoids the "CUDA is not
    # available" warning.
    scaler = GradScaler(enabled=device.type == "cuda")

    # NOTE(review): the test split drives both checkpoint selection and early
    # stopping, so the reported accuracy is optimistically biased. A held-out
    # validation split would fix this.
    best_acc = 0
    for epoch in range(1, epochs + 1):
        train_one_epoch(model, device, train_loader, optimizer, criterion, epoch, scaler)
        acc = test(model, device, test_loader, criterion)
        scheduler.step()  # Advance the StepLR schedule once per epoch

        if acc > best_acc:
            best_acc = acc
            torch.save(model.state_dict(), checkpoint_path)

        # Stop training once the target accuracy is reached
        if target_acc is not None and acc >= target_acc:
            print(f"Stopping early as accuracy reached {acc:.2f}% at epoch {epoch}.")
            break

    return best_acc

# Entry point: run the training pipeline only when this code is executed as
# the main module (e.g. as a script or notebook cell), not when it is imported.
if __name__ == '__main__':
    main()


Files already downloaded and verified
Files already downloaded and verified


  scaler = GradScaler()
  with autocast():
Epoch: 1 Loss: 1.6182 Acc: 37.09%: 100%|██████████| 782/782 [00:25<00:00, 30.75it/s]



Test set: Average loss: 1.4828, Accuracy: 44.95%



Epoch: 2 Loss: 1.0413 Acc: 50.91%: 100%|██████████| 782/782 [00:25<00:00, 30.24it/s]



Test set: Average loss: 1.1030, Accuracy: 60.24%



Epoch: 3 Loss: 2.0622 Acc: 56.76%: 100%|██████████| 782/782 [00:24<00:00, 31.59it/s]



Test set: Average loss: 1.4109, Accuracy: 53.88%



Epoch: 4 Loss: 0.7437 Acc: 59.81%: 100%|██████████| 782/782 [00:24<00:00, 31.77it/s]



Test set: Average loss: 0.9661, Accuracy: 65.63%



Epoch: 5 Loss: 1.1301 Acc: 62.56%: 100%|██████████| 782/782 [00:24<00:00, 32.14it/s]



Test set: Average loss: 0.9037, Accuracy: 68.43%



Epoch: 6 Loss: 0.5965 Acc: 64.87%: 100%|██████████| 782/782 [00:24<00:00, 32.11it/s]



Test set: Average loss: 0.8380, Accuracy: 71.18%



Epoch: 7 Loss: 1.0802 Acc: 66.65%: 100%|██████████| 782/782 [00:25<00:00, 31.23it/s]



Test set: Average loss: 0.7943, Accuracy: 72.33%



Epoch: 8 Loss: 1.0431 Acc: 67.95%: 100%|██████████| 782/782 [00:25<00:00, 30.88it/s]



Test set: Average loss: 0.8190, Accuracy: 71.98%



Epoch: 9 Loss: 0.9407 Acc: 68.69%: 100%|██████████| 782/782 [00:25<00:00, 30.24it/s]



Test set: Average loss: 0.7691, Accuracy: 73.41%



Epoch: 10 Loss: 0.7704 Acc: 70.36%: 100%|██████████| 782/782 [00:25<00:00, 30.12it/s]



Test set: Average loss: 0.7184, Accuracy: 75.46%



Epoch: 11 Loss: 0.9478 Acc: 71.01%: 100%|██████████| 782/782 [00:25<00:00, 30.20it/s]



Test set: Average loss: 0.6935, Accuracy: 75.54%



Epoch: 12 Loss: 0.7612 Acc: 71.91%: 100%|██████████| 782/782 [00:25<00:00, 30.15it/s]



Test set: Average loss: 0.6477, Accuracy: 77.88%



Epoch: 13 Loss: 1.2653 Acc: 72.87%: 100%|██████████| 782/782 [00:25<00:00, 30.50it/s]



Test set: Average loss: 0.6439, Accuracy: 78.31%



Epoch: 14 Loss: 0.2253 Acc: 73.46%: 100%|██████████| 782/782 [00:25<00:00, 30.37it/s]



Test set: Average loss: 0.6481, Accuracy: 77.84%



Epoch: 15 Loss: 0.4260 Acc: 73.65%: 100%|██████████| 782/782 [00:25<00:00, 30.50it/s]



Test set: Average loss: 0.5827, Accuracy: 79.64%



Epoch: 16 Loss: 1.0244 Acc: 74.40%: 100%|██████████| 782/782 [00:25<00:00, 31.21it/s]



Test set: Average loss: 0.5980, Accuracy: 79.25%



Epoch: 17 Loss: 0.8201 Acc: 74.65%: 100%|██████████| 782/782 [00:24<00:00, 31.74it/s]



Test set: Average loss: 0.5857, Accuracy: 79.86%



Epoch: 18 Loss: 0.8455 Acc: 75.05%: 100%|██████████| 782/782 [00:24<00:00, 31.92it/s]



Test set: Average loss: 0.5653, Accuracy: 80.74%



Epoch: 19 Loss: 0.4646 Acc: 75.76%: 100%|██████████| 782/782 [00:24<00:00, 31.79it/s]



Test set: Average loss: 0.5799, Accuracy: 79.79%



Epoch: 20 Loss: 1.2037 Acc: 76.09%: 100%|██████████| 782/782 [00:24<00:00, 31.77it/s]



Test set: Average loss: 0.5570, Accuracy: 81.46%



Epoch: 21 Loss: 0.6223 Acc: 76.38%: 100%|██████████| 782/782 [00:24<00:00, 31.60it/s]



Test set: Average loss: 0.5494, Accuracy: 81.39%



Epoch: 22 Loss: 0.7333 Acc: 76.99%: 100%|██████████| 782/782 [00:25<00:00, 31.06it/s]



Test set: Average loss: 0.5751, Accuracy: 80.10%



Epoch: 23 Loss: 0.6344 Acc: 77.19%: 100%|██████████| 782/782 [00:25<00:00, 30.56it/s]



Test set: Average loss: 0.5658, Accuracy: 80.53%



Epoch: 24 Loss: 0.2378 Acc: 77.30%: 100%|██████████| 782/782 [00:25<00:00, 30.41it/s]



Test set: Average loss: 0.5601, Accuracy: 81.17%



Epoch: 25 Loss: 0.5202 Acc: 77.74%: 100%|██████████| 782/782 [00:25<00:00, 30.44it/s]



Test set: Average loss: 0.5173, Accuracy: 82.54%



Epoch: 26 Loss: 0.8087 Acc: 77.99%: 100%|██████████| 782/782 [00:25<00:00, 30.57it/s]



Test set: Average loss: 0.5610, Accuracy: 81.26%



Epoch: 27 Loss: 0.3598 Acc: 78.22%: 100%|██████████| 782/782 [00:25<00:00, 30.63it/s]



Test set: Average loss: 0.5279, Accuracy: 82.30%



Epoch: 28 Loss: 0.3681 Acc: 78.18%: 100%|██████████| 782/782 [00:25<00:00, 30.50it/s]



Test set: Average loss: 0.5524, Accuracy: 81.67%



Epoch: 29 Loss: 0.8203 Acc: 78.88%: 100%|██████████| 782/782 [00:25<00:00, 30.27it/s]



Test set: Average loss: 0.5110, Accuracy: 82.94%



Epoch: 30 Loss: 0.3918 Acc: 79.10%: 100%|██████████| 782/782 [00:25<00:00, 30.41it/s]



Test set: Average loss: 0.5582, Accuracy: 81.17%



Epoch: 31 Loss: 0.7716 Acc: 81.89%: 100%|██████████| 782/782 [00:25<00:00, 31.14it/s]



Test set: Average loss: 0.4573, Accuracy: 84.75%



Epoch: 32 Loss: 0.5386 Acc: 82.39%: 100%|██████████| 782/782 [00:24<00:00, 31.80it/s]



Test set: Average loss: 0.4522, Accuracy: 84.80%



Epoch: 33 Loss: 0.4422 Acc: 82.51%: 100%|██████████| 782/782 [00:24<00:00, 31.92it/s]



Test set: Average loss: 0.4476, Accuracy: 84.98%



Epoch: 34 Loss: 0.6095 Acc: 82.94%: 100%|██████████| 782/782 [00:25<00:00, 31.28it/s]



Test set: Average loss: 0.4489, Accuracy: 85.19%

Stopping early as accuracy reached 85.19% at epoch 34.
