In [1]:
import torch
from torchvision import datasets, transforms
from torchvision.transforms import v2
import matplotlib.pyplot as plt
import numpy as np
import sys
import torch.nn as nn

sys.path.insert(0, '../')

# CIFAR-10 train and test splits, each converted to tensors by its own ToTensor transform.
# Order matters: the train split is constructed (and verified/downloaded) first.
trainData, testData = (
    datasets.CIFAR10(root='../data', train=is_train, download=True, transform=transforms.ToTensor())
    for is_train in (True, False)
)

Files already downloaded and verified
Files already downloaded and verified


In [2]:
# Prefer CUDA, then Apple's Metal backend (mps), and fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

Using mps device


In [3]:
# Shape of the raw image array held by the training dataset; presumably (N, H, W, C), see output.
trainData.data.shape

(50000, 32, 32, 3)

In [4]:
from torch.utils.data import DataLoader, Subset

# Hold out a class-balanced validation set from the CIFAR-10 training split:
# `samples_per_class` images of every class go to validation, the rest to training.
num_classes = 10
samples_per_class = 1000
total_val_samples = num_classes * samples_per_class
split_seed = 0  # fixed seed so the train/val split is identical across kernel restarts

# Group sample indices by label. Reading `trainData.targets` avoids iterating the
# dataset itself, which would load and ToTensor-convert every image just to get its label.
class_indices = [[] for _ in range(num_classes)]
for idx, label in enumerate(trainData.targets):
    class_indices[label].append(idx)

rng = np.random.default_rng(split_seed)
val_indices = []
train_indices = []
for class_idx in class_indices:
    rng.shuffle(class_idx)
    val_indices.extend(class_idx[:samples_per_class])
    train_indices.extend(class_idx[samples_per_class:])

# Sanity checks: the expected validation size and no overlap between the two splits.
assert len(val_indices) == total_val_samples, "some class has fewer than samples_per_class images"
assert not set(val_indices) & set(train_indices)

# Create Subsets for the train and validation datasets
train_dataset = Subset(trainData, train_indices)
val_dataset = Subset(trainData, val_indices)

# Create DataLoaders. Only the training loader needs shuffling; the validation and
# test loaders are used purely for evaluation.
trainLoader = DataLoader(train_dataset, batch_size=256, shuffle=True)
valLoader = DataLoader(val_dataset, batch_size=256, shuffle=False)
testLoader = DataLoader(testData, batch_size=256, shuffle=False)

print(f"Number of training samples: {len(train_dataset)}")
print(f"Number of validation samples: {len(val_dataset)}")

Number of training samples: 40000
Number of validation samples: 10000


In [5]:
# CovModel2 comes from the project-local `lxmodels` module (presumably found via the
# '../' entry added to sys.path earlier); its layer structure is shown in this cell's output.
from lxmodels import CovModel2

    
model = CovModel2()
def init_weights(m):
    """Xavier-uniform initialise the weight of a Linear or Conv2d module; ignore other modules.

    Meant to be passed to `model.apply(init_weights)`, which calls it on every submodule.
    Biases are left with their existing initialisation.
    """
    # isinstance (rather than an exact type comparison) also covers subclasses of these layers.
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        nn.init.xavier_uniform_(m.weight)
# Re-initialise weights via init_weights, then move to `device`. Both calls return the
# module itself, so the chain evaluates to `model` and the cell displays it.
model.apply(init_weights).to(device)

CovModel2(
  (conv3): Sequential(
    (0): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): Print()
    (3): ReLU()
    (4): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (5): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): Print()
    (7): ReLU()
    (8): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): Print()
    (11): ReLU()
    (12): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (14): Print()
    (15): ReLU()
    (16): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (17): Print()
    (18): ReLU()
    (19): Flatten(start_dim=1, end_dim=-1)
    (20): Linear(in

In [6]:
# Training hyperparameters.
# The learning rate is passed to Adam explicitly so this constant actually controls the
# optimizer; 1e-3 is Adam's default, so training behaves as before.
learning_rate = 1e-3
batch_size = 256  # should match the batch_size given to the DataLoaders
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
epochs = 80

In [7]:
import sys

# Make the parent directory importable for the project-local `util` module.
# Guarded so re-running this cell does not keep prepending duplicate entries.
if '../' not in sys.path:
    sys.path.insert(0, '../')

from util import ModelSaver

# Project-local checkpoint helper, rooted at the current directory.
saver = ModelSaver('./')

# Per-epoch metrics appended by train_loop: [trainLoss, trainCorrect, validLoss, validCorrect].
matrix = []

def _evaluate(dataloader, model, loss_fn):
    """Return (mean per-sample loss, accuracy in [0, 1]) of `model` over all of `dataloader`.

    Assumes `loss_fn` averages over the batch (the default ``reduction='mean'``), so each
    batch loss is re-weighted by its batch size before averaging over the dataset.
    """
    total_loss, correct, count = 0.0, 0.0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            total_loss += loss_fn(pred, y).item() * len(y)
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
            count += len(y)
    if count == 0:
        return 0.0, 0.0
    return total_loss / count, correct / count


def train_loop(dataloader, model, loss_fn, optimizer, ep):
    """Train `model` for one epoch on `dataloader`, then record train/validation metrics.

    Every training image is augmented independently (colour jitter, horizontal flip,
    small rotation, random resized crop) before the forward pass. After the epoch the
    model is evaluated in eval mode on `dataloader` and on the global `valLoader`.
    `saver.save` is called with the validation accuracy, and
    [trainLoss, trainCorrect, validLoss, validCorrect] is appended to the global `matrix`.
    Losses are mean per-sample losses; the "Correct" entries are accuracies in [0, 1].

    Args:
        dataloader: training DataLoader yielding (images, labels) batches.
        model: the network to train.
        loss_fn: criterion, e.g. nn.CrossEntropyLoss() with mean reduction.
        optimizer: optimizer over `model`'s parameters.
        ep: zero-based epoch index (currently unused).
    """
    size = len(dataloader.dataset)
    # Built once per call rather than once per batch.
    # For throughput, this could instead be a Dataset-level transform so that
    # DataLoader workers perform it in parallel.
    augs = v2.Compose([
        v2.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
        v2.RandomHorizontalFlip(),
        v2.RandomRotation(5),
        v2.RandomResizedCrop(32, scale=(0.9, 1.0), ratio=(0.9, 1.0)),
    ])
    # Training mode is required: the model contains BatchNorm layers.
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        # v2 random transforms draw their parameters once per call, so applying `augs`
        # to the whole batch would give every image the same flip/rotation/crop/jitter.
        # Augment each image independently instead.
        X = torch.stack([augs(img) for img in X])
        X, y = X.to(device), y.to(device)
        # Compute prediction and loss
        pred = model(X)
        loss = loss_fn(pred, y)
        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if batch % 100 == 0:
            loss_value = loss.item()
            current = batch * (dataloader.batch_size or len(X)) + len(X)
            print(f"loss: {loss_value:>7f}  [{current:>5d}/{size:>5d}]")

    model.eval()
    trainLoss, trainCorrect = _evaluate(dataloader, model, loss_fn)
    validLoss, validCorrect = _evaluate(valLoader, model, loss_fn)
    saver.save(model, validCorrect, "haha")
    matrix.append([trainLoss, trainCorrect, validLoss, validCorrect])
    
    
# Run `epochs` training epochs; train_loop receives the zero-based epoch index.
for epoch in range(1, epochs + 1):
    print(f"Epoch {epoch}\n-------------------------------")
    train_loop(trainLoader, model, loss_fn, optimizer, epoch - 1)
print("Done!")



Epoch 1
-------------------------------
loss: 3.344633  [  256/40000]
loss: 1.518186  [25856/40000]
Epoch 2
-------------------------------
loss: 1.380876  [  256/40000]
loss: 1.311345  [25856/40000]
Epoch 3
-------------------------------
loss: 1.142421  [  256/40000]
loss: 0.931851  [25856/40000]
Epoch 4
-------------------------------
loss: 1.055627  [  256/40000]
loss: 1.078769  [25856/40000]
Epoch 5
-------------------------------
loss: 0.909029  [  256/40000]


KeyboardInterrupt: 

In [None]:
# Inspect ModelSaver's `lastMax` attribute (presumably the best validation score saved so far; confirm in util.ModelSaver).
saver.lastMax

In [None]:
# Display the model's `conv3` submodule and its indexed layers.
model.conv3

In [None]:
# Check whether layer 1 of conv3 is a BatchNorm2d (isinstance also accepts subclasses).
isinstance(model.conv3[1], nn.BatchNorm2d)

In [None]:
# Take one training image (the loader shuffles, so it changes per run) to visualise.
trainFeatures, trainLabels = next(iter(trainLoader))
X = trainFeatures[0]
X = X.to(device)

# Restore the checkpoint at './haha'. map_location remaps tensors saved on another device
# (e.g. mps or cuda) onto the current `device`, so loading works on any machine.
model.load_state_dict(torch.load('./haha', map_location=device))

model.eval()

def layerOutput(model, input, layer):
    """Return the activation after `model.conv3[layer]` as a NumPy array.

    Runs `input` through `model.conv3[0]` .. `model.conv3[layer]` (inclusive), including
    BatchNorm layers, in eval mode and without gradients, so the result matches what
    the model actually computes at inference time.

    Args:
        model: a module with an indexable `conv3` container (e.g. nn.Sequential).
        input: a single image tensor (C, H, W) or a batch (N, C, H, W).
        layer: index of the last `conv3` sub-layer to apply.

    Returns:
        np.ndarray of shape (C_out, H_out, W_out) for a single image, or
        (N, C_out, H_out, W_out) for a batch.
    """
    # BatchNorm2d only accepts 4-D input, so give a single image a batch dimension.
    unbatched = input.dim() == 3
    out = input.unsqueeze(0) if unbatched else input
    # Eval mode makes BatchNorm use (and not update) its running statistics;
    # the caller's train/eval mode is restored afterwards.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for module in list(model.conv3)[:layer + 1]:
                out = module(out)
    finally:
        model.train(was_training)
    if unbatched:
        out = out[0]
    return out.cpu().numpy()

# Activations after conv3 layers 0, 4, 8 and 12, computed in that order.
images, images2, images4, images6 = (
    layerOutput(model, X, layer_idx) for layer_idx in (0, 4, 8, 12)
)

def plot_images(images, ncols=8):
    """Show each 2-D map in `images` (e.g. the channels of one activation) on a grid.

    Args:
        images: sequence or array of 2-D arrays, e.g. shape (C, H, W).
        ncols: number of subplot columns (default 8).
    """
    n = len(images)
    if n == 0:
        return
    # Round up so that a partial last row is still drawn.
    nrows = (n + ncols - 1) // ncols
    # squeeze=False keeps `axes` 2-D even when there is only one row.
    fig, axes = plt.subplots(nrows, ncols, figsize=(20, 10), squeeze=False)
    for ax, image in zip(axes.flat, images):
        ax.imshow(image)
    # Hide ticks on every cell, including unused cells in the last row.
    for ax in axes.flat:
        ax.axis('off')
    plt.show()

# One figure per captured activation, in layer order.
for activation_maps in (images, images2, images4, images6):
    plot_images(activation_maps)

In [None]:
import matplotlib.pyplot as plt

# Plot the per-epoch metrics in `matrix`; each row is
# [trainLoss, trainCorrect, validLoss, validCorrect].
if not matrix:
    print("No epochs recorded yet - run the training cell first.")
else:
    trainLosses, trainCorrects, validLosses, validCorrects = zip(*matrix)
    epoch_numbers = range(1, len(matrix) + 1)

    fig, ax = plt.subplots(figsize=(15, 10))
    ax.plot(epoch_numbers, trainLosses, label='Training Loss')
    ax.plot(epoch_numbers, trainCorrects, label='Training Accuracy')
    ax.plot(epoch_numbers, validLosses, label='Validation Loss')
    ax.plot(epoch_numbers, validCorrects, label='Validation Accuracy')
    ax.set(
        title='Training and validation metrics per epoch',
        xlabel='Epoch',
        ylabel='Loss / accuracy',
    )
    ax.legend()
    # plt.show() does not take a figure to display; call it without arguments.
    plt.show()

In [None]:
import torchvision.models as models
from d2l import torch as d2l

# Baseline: an ImageNet-pretrained AlexNet evaluated on the CIFAR-10 validation split.
# NOTE(review): the pretrained head predicts 1000 ImageNet classes whose indices do not
# correspond to the 10 CIFAR-10 labels, so the numbers below are not a meaningful CIFAR-10
# score. TODO: replace `alexnet.classifier[6]` with nn.Linear(4096, 10) and fine-tune.
weights = models.AlexNet_Weights.IMAGENET1K_V1
alexnet = models.alexnet(weights=weights)
alexnet.to(device)
# The weights' bundled preprocessing (resize, center-crop to 224, ImageNet normalization).
# Raw 32x32 CIFAR images are too small for AlexNet's stack of stride-2 pooling layers.
preprocess = weights.transforms()

size = len(valLoader.dataset)
alexnet.eval()
validLoss, validCorrect = 0, 0
with torch.no_grad():
    for X, y in valLoader:
        # Preprocess on the CPU before moving, so the resize kernel need not exist on `device`.
        X, y = preprocess(X).to(device), y.to(device)
        pred = alexnet(X)
        # loss_fn averages over the batch; re-weight by batch size so that dividing
        # by `size` below yields the mean per-sample loss.
        validLoss += loss_fn(pred, y).item() * len(y)
        validCorrect += (pred.argmax(1) == y).type(torch.float).sum().item()

validLoss /= size
validCorrect /= size
print(f"Validation Error: \n Accuracy: {(100*validCorrect):>0.1f}%, Avg loss: {validLoss:>8f} \n")