In [50]:
import torch
import torch.nn as nn
from torch import optim
from torchvision import transforms, datasets
from torch.utils.data import DataLoader

# progress bar imports
from time import sleep
from tqdm import tqdm

In [28]:
dataRoot = 'data'
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [45]:
# Get caltech dataset
batch_size = 64
trainSet = datasets.CIFAR10(root=dataRoot, train=True, download=True, transform=transforms.ToTensor())
testSet = datasets.CIFAR10(root=dataRoot, train=False, download=True, transform=transforms.ToTensor())
trainLoader = DataLoader(trainSet, batch_size=batch_size, shuffle=True)
testLoader = DataLoader(testSet, batch_size=batch_size)

Files already downloaded and verified
Files already downloaded and verified


In [16]:
firstDatapoint = trainSet[0]
firstImage = firstDatapoint[0]
firstLabel = firstDatapoint[1]
print("label", firstLabel)
print("shape", firstImage.shape)

label 6
shape torch.Size([3, 32, 32])


In [41]:
class CifarNetwork(nn.Module):
    
    def __init__(self):
        super().__init__()
        
        self.name = 'CifarNetwork'
        # TODO what happens when kernel size is even?
        self.conv1 = nn.Conv2d(3, 32, 3, 1) # 32x32 -> 30x30, no padding
        self.conv2 = nn.Conv2d(32, 64, 5, 1) # 30x30 -> 26x26
        
        self.batch_flattener = nn.Flatten(start_dim=1) # make flat from dim 1
        
        self.fc1 = nn.Linear(26*26*64, 10)
        self.activation = nn.LeakyReLU(0.01)
        self.LogSoftmax = nn.LogSoftmax(dim=1) # For batch, output is in dim 1 for each image.
    
    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.batch_flattener(out)
#         print("shape after flattening", out.shape)
        out = self.fc1(out)
        out = self.activation(out)
        return self.LogSoftmax(out)
    

In [42]:
model = CifarNetwork().to(device)
print(model)
lossFunc = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

CifarNetwork(
  (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1))
  (batch_flattener): Flatten(start_dim=1, end_dim=-1)
  (fc1): Linear(in_features=43264, out_features=10, bias=True)
  (activation): LeakyReLU(negative_slope=0.01)
  (LogSoftmax): LogSoftmax(dim=1)
)


In [62]:
def saveCheckpoint(epoch, model, optimizer):
    
    torch.save({
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict()
        
    }, f"{model.name}-checkpoint-{epoch}")

# train
n_epochs = 5
for epoch in range(n_epochs):
    print("Epoch:", epoch)
    epoch_loss = 0
    count = 0
    with tqdm(trainLoader, unit="batch") as tepoch:
        tepoch.set_description(f"Epoch {epoch}")
        for images, labels in tepoch:
            count += 1
            images = images.to(device)
            labels = labels.to(device)
#             print("label shape", labels.shape)

            output = model(images)

            optimizer.zero_grad()
            batch_loss = lossFunc(output, labels)
            batch_loss.backward()
            optimizer.step()

            epoch_loss += batch_loss.item()
            predictions1 = output.argmax(dim=1)
#             predictions2 = output.argmax(dim=1, keepdim=True).squeeze()
#             print("prediction1", predictions1)
#             print("predictions2", predictions2)
            correct = (predictions == labels).sum()
            accuracy = correct / labels.size(0)

#             if count % 6 == 0:
#                 print(f"Epoch progress: {round(count * 100/len(trainLoader))}%")
                
            tepoch.set_postfix(batch_loss=batch_loss.item(), accuracy=100. * accuracy.item())
#             sleep(0.1)

        
    # checkpoint
    if epoch % 5 == 0:
        saveCheckpoint(epoch, model, optimizer)
            
    print(f"Epoch{epoch}, Average training loss: {epoch_loss / len(trainLoader)}")
    
    
        

Epoch 0:   0%|          | 1/782 [00:00<01:29,  8.77batch/s, accuracy=4.69, batch_loss=1.72]

Epoch: 0


Epoch 0:  18%|█▊        | 144/782 [00:16<01:11,  8.88batch/s, accuracy=17.2, batch_loss=1.78]


KeyboardInterrupt: 