In [1]:
import torch
from torchvision import datasets, transforms
from torchvision.transforms import v2
import matplotlib.pyplot as plt
import numpy as np
import sys
import torch.nn as nn
from torch.utils.data import DataLoader, Subset


sys.path.insert(0, '../')

In [2]:
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

Using mps device


In [3]:
## STL10

trainData = datasets.STL10(root='../data', download=True, transform=transforms.ToTensor())
testData = datasets.STL10(root='../data', split='test', download=True, transform=transforms.ToTensor())

# Combine train and test datasets
combined_dataset = torch.utils.data.ConcatDataset([trainData, testData])

# Find indices for each class
targets = np.array([y for _, y in combined_dataset])
train_indices = []
valid_indices = []

num_classes = 10
num_valid_per_class = 200

for class_idx in range(num_classes):
    class_indices = np.where(targets == class_idx)[0]
    np.random.shuffle(class_indices)
    valid_indices.extend(class_indices[:num_valid_per_class])
    train_indices.extend(class_indices[num_valid_per_class:])

# Create Subsets
train_dataset = Subset(combined_dataset, train_indices)
valid_dataset = Subset(combined_dataset, valid_indices)

# Create DataLoaders
trainLoader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2)
valLoader = DataLoader(valid_dataset, batch_size=64, shuffle=False, num_workers=2)


Files already downloaded and verified
Files already downloaded and verified


In [8]:
trainData.data.shape

(5000, 3, 96, 96)

In [5]:
from lxmodels import CovModel3

    
model = CovModel3()
def init_weights(m):
    if type(m) == nn.Linear or type(m) == nn.Conv2d:
        nn.init.xavier_uniform_(m.weight)
model.apply(init_weights)
model.to(device)

CovModel3(
  (conv3): Sequential(
    (0): Conv2d(3, 32, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): Print()
    (3): ReLU()
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (6): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): Print()
    (8): ReLU()
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): Print()
    (13): ReLU()
    (14): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (16): Print()
    (17): ReLU()
 

In [6]:
learning_rate = 1
batch_size = 64
epochs = 5
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())
epochs = 80

In [7]:
import sys
sys.path.insert(0, '../')

from util import ModelSaver

saver = ModelSaver('./')

matrix = []

def train_loop(dataloader, model, loss_fn, optimizer, ep):
    size = len(dataloader.dataset)
    # Set the model to training mode - important for batch normalization and dropout layers
    # Unnecessary in this situation but added for best practices
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        augs = v2.Compose([
            v2.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
            v2.RandomHorizontalFlip(),
            v2.RandomRotation(5),
            v2.RandomResizedCrop(32, scale=(0.9, 1.0), ratio=(0.9, 1.0)),
        ])
        X = augs(X)
        X, y = X.to(device), y.to(device)
        # Compute prediction and loss
        pred = model(X)
        loss = loss_fn(pred, y)
        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * batch_size + len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

    model.eval()
    trainLoss, trainCorrect = 0, 0
    validLoss, validCorrect = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            trainLoss += loss_fn(pred, y).item()
            trainCorrect += (pred.argmax(1) == y).type(torch.float).sum().item()

        for X, y in valLoader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            validLoss += loss_fn(pred, y).item()
            validCorrect += (pred.argmax(1) == y).type(torch.float).sum().item()

    trainLoss /= size
    trainCorrect /= size
    validLoss /= len(valLoader.dataset)
    validCorrect /= len(valLoader.dataset)
    # Multiply losses by 100
    trainLoss *= 30
    validLoss *= 30
    saver.save(model, validCorrect, "haha")
    matrix.append([trainLoss, trainCorrect, validLoss, validCorrect])
    
    
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(trainLoader, model, loss_fn, optimizer, t)
print("Done!")



Epoch 1
-------------------------------
torch.Size([64, 32, 32, 32]) 1
torch.Size([64, 64, 16, 16]) 2
torch.Size([64, 64, 8, 8]) 3
torch.Size([64, 128, 8, 8]) 4
torch.Size([64, 128, 8, 8]) 5


RuntimeError: linear(): input and weight.T shapes cannot be multiplied (64x8192 and 73728x512)

In [None]:
saver.lastMax

In [None]:
model.conv3

In [None]:
type(model.conv3[1]) == nn.BatchNorm2d

In [None]:
trainFeatures, trainLabels = next(iter(trainLoader))
X = trainFeatures[0]
# X = v2.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5)(X)
# plt.imshow(X.permute(1, 2, 0))
# plt.show()
X = X.to(device)

model.load_state_dict(torch.load('./haha'))

# model.eval()

def layerOutput(model, input, layer):
    for i in range(layer+1):
        if type(model.conv3[i]) == nn.BatchNorm2d:
            continue
        input = model.conv3[i](input)
    return input.cpu().detach().numpy()

# out1 = model.conv3[0](X)  # Conv2d
# out2 = model.conv3[3](out1)  # ReLU
# out3 = model.conv3[2](out2)  # Conv2d
# out4 = model.conv3[3](out3)  # ReLU
# out5 = model.conv3[4](out4)  # Conv2d
# out6 = model.conv3[5](out5)  # ReLU
# out7 = model.conv3[6](out4)  # Conv2d

# images = model.conv3[0](X).cpu().detach().numpy()
# images2 = out2.cpu().detach().numpy()
# images4 = out4.cpu().detach().numpy()
# images6 = out6.cpu().detach().numpy()
images = layerOutput(model, X, 0)
images2 = layerOutput(model, X, 4)
images4 = layerOutput(model, X, 8)
images6 = layerOutput(model, X, 12)

def plot_images(images):
    fig, axes = plt.subplots(len(images) // 8, 8, figsize=(20, 10))

    # Loop through the images and plot them on the subplots
    for i in range(len(images)):
        row, col = i // 8, i % 8
        axes[row, col].imshow(images[i])
        axes[row, col].axis('off')
    plt.show()

plot_images(images)
plot_images(images2)
plot_images(images4)
plot_images(images6)

In [None]:
import matplotlib.pyplot as plt

# Assuming matrix is a list of lists where each list contains [trainLoss, trainCorrect, validLoss, validCorrect]
trainLosses, trainCorrects, validLosses, validCorrects = zip(*matrix)

# Create a single plot
fig, ax = plt.subplots(figsize=(15, 10))

# Plot training loss
ax.plot(trainLosses, label='Training Loss')

# Plot training accuracy
ax.plot(trainCorrects, label='Training Accuracy')

# Plot validation loss
ax.plot(validLosses, label='Validation Loss')

# Plot validation accuracy
ax.plot(validCorrects, label='Validation Accuracy')

# Add a legend
ax.legend()

# Display the plot
plt.show(fig)

In [None]:
import torchvision.models as models
from d2l import torch as d2l

alexnet = models.alexnet(weights='IMAGENET1K_V1')
alexnet.to(device)

size = len(valLoader.dataset)
alexnet.eval()
validLoss, validCorrect = 0, 0
with torch.no_grad():
    for X, y in valLoader:
        X, y = X.to(device), y.to(device)
        pred = alexnet(X)
        validLoss += loss_fn(pred, y).item()
        validCorrect += (pred.argmax(1) == y).type(torch.float).sum().item()

validLoss /= len(valLoader.dataset)
validCorrect /= len(valLoader.dataset)
# Multiply losses by 100
validLoss *= 30
print(f"Validation Error: \n Accuracy: {(100*validCorrect):>0.1f}%, Avg loss: {validLoss:>8f} \n")