# Dependencies

In [30]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

# Model

In [31]:
class Alexnet(nn.Module):
    def __init__(self, num_classes=45):
        super().__init__()

        self.net = nn.Sequential(
            self.conv_blocks(3, 96, 11, 4, max_pool=True, local_normalization=True),  # b X 96 X 55 X 55 -> maxpool: b X 96 X 27 X 27
            self.conv_blocks(96, 256, 5, padding=2, max_pool=True, local_normalization=True),  # b X 256 X 27 X 27 -> maxpool: b X 256 X 13 X 13
            self.conv_blocks(256, 384, 3, padding=1),  # b X 384 X 13 X 13
            self.conv_blocks(384, 384, 3, padding=1),  # b X 384 X 13 X 13
            self.conv_blocks(384, 256, 3, padding=1, max_pool=True),  # b X 256 X 13 X 13 -> maxpool: b X 256 x 6 X 6
        )
        self.fc = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        x = self.net(x)
        x = x.view(-1, 256 * 6 * 6)
        x = self.fc(x)
        return x

    def conv_blocks(self, in_channels, out_channels, kernel_size, stride=1, padding=0, max_pool=False, local_normalization=False):
        layers = []
        layers += [nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)]
        layers += [nn.ReLU()]
        if local_normalization:
            layers += [nn.LocalResponseNorm(size=5, alpha=0.0001, beta=0.75, k=2)]
        if max_pool:
            layers += [nn.MaxPool2d((3, 3), stride=2)]
        return nn.Sequential(*layers)

    def init_bias(self):
        for layer in self.net:
            if isinstance(layer, nn.Conv2d):
                nn.init.normal_(layer.weight, mean=0, std=0.01)
                nn.init.constant_(layer.bias, 0)
        nn.init.constant_(self.net[0][0].bias, 1)
        nn.init.constant_(self.net[3][0].bias, 1)
        nn.init.constant_(self.net[4][0].bias, 1)


# HyperParameters

In [44]:
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
IMAGE_SIZE = 227
BATCH_SIZE = 64
LR = 0.0001
EPOCHS = 90
PATH = "model_checkpoint.pth"

# Dataset

In [12]:
# prompt: /content/drive/MyDrive/data/Mammals_Images.zip extract

!unzip /content/drive/MyDrive/data/Mammals_Images.zip -d /content

Archive:  /content/drive/MyDrive/data/Mammals_Images.zip
replace /content/mammals/african_elephant/african_elephant-0001.jpg? [y]es, [n]o, [A]ll, [N]one, [r]ename: 

In [45]:
transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    #transforms.CenterCrop(IMAGE_SIZE),
    transforms.ToTensor(),
])

dataset = datasets.ImageFolder(root='/content/mammals', transform = transform)
dataloader = DataLoader(dataset, batch_size = BATCH_SIZE, shuffle=True, num_workers = 2, drop_last=True, pin_memory = True)

# Model Initialization



In [46]:
model = Alexnet().to(DEVICE)
model.init_bias()
optimizer = torch.optim.Adam(model.parameters(), lr = LR)
criterion = nn.CrossEntropyLoss()
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

# Training

In [47]:
best_accuracy = 0.0

for epoch in range(EPOCHS):
    lr_scheduler.step()
    epoch_loss = 0.0
    correct_predictions = 0
    total_samples = 0

    for image, label in tqdm(dataloader):
        image = image.to(DEVICE)
        label = label.to(DEVICE)

        optimizer.zero_grad()
        score = model(image)
        loss = criterion(score, label)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

        _, preds = torch.max(score, 1)
        correct_predictions += torch.sum(preds == label).item()
        total_samples += label.size(0)
    epoch_accuracy = correct_predictions / total_samples
    epoch_loss /= len(dataloader)

    print('Epoch: {} \tLoss: {:.4f} \tAcc: {:.4f}'.format(epoch + 1, epoch_loss, epoch_accuracy))

    # Save the model if it has the best accuracy so far
    if epoch_accuracy > best_accuracy:
        best_accuracy = epoch_accuracy
        torch.save({'epoch': epoch + 1,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'best_accuracy': best_accuracy}, PATH)


  self.pid = os.fork()
100%|██████████| 214/214 [00:42<00:00,  5.07it/s]


Epoch: 1 	Loss: 3.8096 	Acc: 0.0223


100%|██████████| 214/214 [00:41<00:00,  5.11it/s]


Epoch: 2 	Loss: 3.7474 	Acc: 0.0352


100%|██████████| 214/214 [00:41<00:00,  5.10it/s]


Epoch: 3 	Loss: 3.5005 	Acc: 0.0789


100%|██████████| 214/214 [00:41<00:00,  5.16it/s]


Epoch: 4 	Loss: 3.2948 	Acc: 0.1212


100%|██████████| 214/214 [00:41<00:00,  5.10it/s]


Epoch: 5 	Loss: 3.1624 	Acc: 0.1530


100%|██████████| 214/214 [00:41<00:00,  5.17it/s]


Epoch: 6 	Loss: 3.0233 	Acc: 0.1841


100%|██████████| 214/214 [00:41<00:00,  5.22it/s]


Epoch: 7 	Loss: 2.8815 	Acc: 0.2154


100%|██████████| 214/214 [00:41<00:00,  5.14it/s]


Epoch: 8 	Loss: 2.7582 	Acc: 0.2488


100%|██████████| 214/214 [00:40<00:00,  5.24it/s]


Epoch: 9 	Loss: 2.6222 	Acc: 0.2815


100%|██████████| 214/214 [00:40<00:00,  5.25it/s]


Epoch: 10 	Loss: 2.4638 	Acc: 0.3213


100%|██████████| 214/214 [00:41<00:00,  5.17it/s]


Epoch: 11 	Loss: 2.3083 	Acc: 0.3561


100%|██████████| 214/214 [00:41<00:00,  5.12it/s]


Epoch: 12 	Loss: 2.1151 	Acc: 0.4017


100%|██████████| 214/214 [00:41<00:00,  5.19it/s]


Epoch: 13 	Loss: 1.8965 	Acc: 0.4608


100%|██████████| 214/214 [00:40<00:00,  5.25it/s]


Epoch: 14 	Loss: 1.6326 	Acc: 0.5279


100%|██████████| 214/214 [00:41<00:00,  5.17it/s]


Epoch: 15 	Loss: 1.3549 	Acc: 0.6087


100%|██████████| 214/214 [00:41<00:00,  5.13it/s]


Epoch: 16 	Loss: 1.0911 	Acc: 0.6728


100%|██████████| 214/214 [00:42<00:00,  5.09it/s]


Epoch: 17 	Loss: 0.8564 	Acc: 0.7361


100%|██████████| 214/214 [00:41<00:00,  5.19it/s]


Epoch: 18 	Loss: 0.6837 	Acc: 0.7915


100%|██████████| 214/214 [00:41<00:00,  5.18it/s]


Epoch: 19 	Loss: 0.5476 	Acc: 0.8283


100%|██████████| 214/214 [00:41<00:00,  5.10it/s]


Epoch: 20 	Loss: 0.4441 	Acc: 0.8593


100%|██████████| 214/214 [00:41<00:00,  5.10it/s]


Epoch: 21 	Loss: 0.3740 	Acc: 0.8831


100%|██████████| 214/214 [00:41<00:00,  5.12it/s]


Epoch: 22 	Loss: 0.3269 	Acc: 0.9005


100%|██████████| 214/214 [00:42<00:00,  5.09it/s]


Epoch: 23 	Loss: 0.2894 	Acc: 0.9087


100%|██████████| 214/214 [00:42<00:00,  5.07it/s]


Epoch: 24 	Loss: 0.2377 	Acc: 0.9258


100%|██████████| 214/214 [00:41<00:00,  5.16it/s]


Epoch: 25 	Loss: 0.2201 	Acc: 0.9303


100%|██████████| 214/214 [00:41<00:00,  5.19it/s]


Epoch: 26 	Loss: 0.2213 	Acc: 0.9320


100%|██████████| 214/214 [00:42<00:00,  5.09it/s]


Epoch: 27 	Loss: 0.1913 	Acc: 0.9428


100%|██████████| 214/214 [00:41<00:00,  5.11it/s]


Epoch: 28 	Loss: 0.1759 	Acc: 0.9464


100%|██████████| 214/214 [00:42<00:00,  5.06it/s]


Epoch: 29 	Loss: 0.1592 	Acc: 0.9498


100%|██████████| 214/214 [00:42<00:00,  5.07it/s]


Epoch: 30 	Loss: 0.0943 	Acc: 0.9717


100%|██████████| 214/214 [00:41<00:00,  5.10it/s]


Epoch: 31 	Loss: 0.0790 	Acc: 0.9760


100%|██████████| 214/214 [00:41<00:00,  5.11it/s]


Epoch: 32 	Loss: 0.0661 	Acc: 0.9807


100%|██████████| 214/214 [00:42<00:00,  5.06it/s]


Epoch: 33 	Loss: 0.0620 	Acc: 0.9807


100%|██████████| 214/214 [00:41<00:00,  5.15it/s]


Epoch: 34 	Loss: 0.0620 	Acc: 0.9820


100%|██████████| 214/214 [00:41<00:00,  5.13it/s]


Epoch: 35 	Loss: 0.0558 	Acc: 0.9831


100%|██████████| 214/214 [00:41<00:00,  5.17it/s]


Epoch: 36 	Loss: 0.0523 	Acc: 0.9845


100%|██████████| 214/214 [00:41<00:00,  5.11it/s]


Epoch: 37 	Loss: 0.0509 	Acc: 0.9840


100%|██████████| 214/214 [00:42<00:00,  5.07it/s]


Epoch: 38 	Loss: 0.0469 	Acc: 0.9869


100%|██████████| 214/214 [00:41<00:00,  5.16it/s]


Epoch: 39 	Loss: 0.0462 	Acc: 0.9851


100%|██████████| 214/214 [00:41<00:00,  5.14it/s]


Epoch: 40 	Loss: 0.0474 	Acc: 0.9857


100%|██████████| 214/214 [00:41<00:00,  5.13it/s]


Epoch: 41 	Loss: 0.0421 	Acc: 0.9882


100%|██████████| 214/214 [00:41<00:00,  5.16it/s]


Epoch: 42 	Loss: 0.0434 	Acc: 0.9866


100%|██████████| 214/214 [00:42<00:00,  5.09it/s]


Epoch: 43 	Loss: 0.0459 	Acc: 0.9861


100%|██████████| 214/214 [00:41<00:00,  5.13it/s]


Epoch: 44 	Loss: 0.0429 	Acc: 0.9878


100%|██████████| 214/214 [00:42<00:00,  5.07it/s]


Epoch: 45 	Loss: 0.0335 	Acc: 0.9899


100%|██████████| 214/214 [00:41<00:00,  5.12it/s]


Epoch: 46 	Loss: 0.0374 	Acc: 0.9897


 21%|██        | 44/214 [00:08<00:33,  5.02it/s]


KeyboardInterrupt: 

In [None]:
# Load model and optimizer state_dicts
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
