In [1]:
from google.colab import drive
drive.mount('/content/drive')

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /content/drive


In [3]:
from torch.autograd import Variable
import torch.nn.functional as F
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import gc
from tqdm import tqdm

from torch.utils.data import DataLoader

from torchvision import datasets
from torch.utils.data.sampler import SubsetRandomSampler

[Sketch-A-Net Architecture](https://drive.google.com/file/d/1RGQb_KeEAWLXu9sFMVdInMhzt7fDKHvl/view?usp=sharing)



In [4]:
def get_sketches():
    TRANSFORM_IMG = transforms.Compose([
        transforms.Resize((225, 225)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

    DATA_DIR = "/content/drive/My Drive/MLBootcamp/Week Three/Day Fourteen/sketches"
    train_data = datasets.ImageFolder(DATA_DIR, transform=TRANSFORM_IMG)
    test_data = datasets.ImageFolder(DATA_DIR, transform=TRANSFORM_IMG)
    num_train = len(train_data)
    indices = list(range(num_train))
    valid_size = .2
    split = int(np.floor(valid_size * num_train))
    np.random.shuffle(indices)
    train_idx, test_idx = indices[split:], indices[:split]
    train_sampler = SubsetRandomSampler(train_idx)
    test_sampler = SubsetRandomSampler(test_idx)
    trainloader = DataLoader(train_data,
                   sampler=train_sampler, batch_size=64)
    testloader = DataLoader(test_data,
                   sampler=test_sampler, batch_size=64)
    return trainloader, testloader

In [5]:
num_epochs = 40
num_classes = 5 
batch_size = 64
learning_rate = 0.001

print(num_epochs, num_classes, batch_size, learning_rate)

40 5 64 0.001


In [6]:
class Sketch_a_Net_CNN(torch.nn.Module):
    def __init__(self):
        super(Sketch_a_Net_CNN, self).__init__()

        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=15, stride=3, padding=0),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2)
        )

        self.layer2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=5, stride=1, padding=0),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2)
        )

        self.layer3 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU()
        )
        
        self.layer4 = nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU()
        )

        self.layer5 = nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2)
        )

        self.layer6 = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=7, stride=1, padding=0),
            nn.ReLU(),
            nn.Dropout2d()
        )

        self.layer7 = nn.Sequential(
            nn.Conv2d(512, 512, kernel_size=1, stride=1, padding=0),
            nn.ReLU(),
            nn.Dropout2d()
        )

        self.fc1 = nn.Sequential(
            nn.Linear(512, 5)
        )

         
    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.layer5(out)
        out = self.layer6(out)
        out = self.layer7(out)
        out = out.reshape(out.size(0), -1)
        out = self.fc1(out)
        return(out)

In [7]:
def train():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = Sketch_a_Net_CNN().to(device)

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    train_loader, val_loader = get_sketches()

    total_step = len(train_loader)

    for epoch in range(num_epochs):
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(tqdm(train_loader), 1):
            inputs = inputs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        # save after every epoch
        torch.save(model.state_dict(), "model.%d" % epoch)

        model.eval()

        train_correct = 0
        train_five_correct = 0
        train_total = 0
        with torch.no_grad():
            for data in tqdm(train_loader):
                images, labels = data
                images = images.to(device)
                labels = labels.to(device)

                outputs = model(images)
                _, predicted = torch.max(outputs.data, 1)
                _, predicted_five = torch.topk(outputs.data, 5, dim=1)

                train_total += labels.size(0)

                train_correct += (predicted == labels).sum().item()


               # train_five_correct += (predicted_five[:, 0] == labels).sum().item()
               # train_five_correct += (predicted_five[:, 1] == labels).sum().item()
               # train_five_correct += (predicted_five[:, 2] == labels).sum().item()
               # train_five_correct += (predicted_five[:, 3] == labels).sum().item()
               # train_five_correct += (predicted_five[:, 4] == labels).sum().item()
        
        print('Top One Error of the network on train images: %d %%' % (
                100 * (1 - train_correct / train_total)))

       # print('Top Five Error of the network on train images: %d %%' % (
        #   100 * (1 - train_five_correct /train_total)))

        correct = 0
        val_correct = 0
        val_total = 0
        with torch.no_grad():
            for data in tqdm(val_loader):
                images, labels = data

                images = images.to(device)
                labels = labels.to(device)

                outputs = model(images)
                _, predicted = torch.max(outputs.data, 1)
                _, predicted_five = torch.topk(outputs.data, 5, dim=1)

                val_total += labels.size(0)

                correct += (predicted == labels).sum().item()

              #  val_correct += (predicted_five[:, 0] == labels).sum().item()
              #  val_correct += (predicted_five[:, 1] == labels).sum().item()
              #  val_correct += (predicted_five[:, 2] == labels).sum().item()
              #  val_correct += (predicted_five[:, 3] == labels).sum().item()
              #  val_correct += (predicted_five[:, 4] == labels).sum().item()
        

        print('Top One Error of the network on validation images: %d %%' % (
                100 * (1 - correct / val_total)))


       # print('Top Five Error of the network on validation images: %d %%' % (
        #   100 * (1 - val_correct / val_total)))
        

        gc.collect()

In [8]:
train()

100%|██████████| 5/5 [01:48<00:00, 21.79s/it]
100%|██████████| 5/5 [00:04<00:00,  1.05it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

Top One Error of the network on train images: 79 %


100%|██████████| 2/2 [00:27<00:00, 13.51s/it]
  0%|          | 0/5 [00:00<?, ?it/s]

Top One Error of the network on validation images: 81 %


100%|██████████| 5/5 [00:04<00:00,  1.05it/s]
100%|██████████| 5/5 [00:04<00:00,  1.09it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

Top One Error of the network on train images: 80 %


100%|██████████| 2/2 [00:01<00:00,  1.71it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Top One Error of the network on validation images: 80 %


100%|██████████| 5/5 [00:04<00:00,  1.05it/s]
100%|██████████| 5/5 [00:04<00:00,  1.08it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

Top One Error of the network on train images: 80 %


100%|██████████| 2/2 [00:01<00:00,  1.73it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Top One Error of the network on validation images: 80 %


100%|██████████| 5/5 [00:04<00:00,  1.06it/s]
100%|██████████| 5/5 [00:04<00:00,  1.09it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

Top One Error of the network on train images: 77 %


100%|██████████| 2/2 [00:01<00:00,  1.70it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Top One Error of the network on validation images: 85 %


100%|██████████| 5/5 [00:04<00:00,  1.05it/s]
100%|██████████| 5/5 [00:04<00:00,  1.09it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

Top One Error of the network on train images: 76 %


100%|██████████| 2/2 [00:01<00:00,  1.77it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Top One Error of the network on validation images: 83 %


100%|██████████| 5/5 [00:04<00:00,  1.05it/s]
100%|██████████| 5/5 [00:04<00:00,  1.09it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

Top One Error of the network on train images: 69 %


100%|██████████| 2/2 [00:01<00:00,  1.73it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Top One Error of the network on validation images: 73 %


100%|██████████| 5/5 [00:04<00:00,  1.06it/s]
100%|██████████| 5/5 [00:04<00:00,  1.09it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

Top One Error of the network on train images: 74 %


100%|██████████| 2/2 [00:01<00:00,  1.68it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Top One Error of the network on validation images: 75 %


100%|██████████| 5/5 [00:04<00:00,  1.06it/s]
100%|██████████| 5/5 [00:04<00:00,  1.10it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

Top One Error of the network on train images: 69 %


100%|██████████| 2/2 [00:01<00:00,  1.74it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Top One Error of the network on validation images: 76 %


100%|██████████| 5/5 [00:04<00:00,  1.06it/s]
100%|██████████| 5/5 [00:04<00:00,  1.10it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

Top One Error of the network on train images: 52 %


100%|██████████| 2/2 [00:01<00:00,  1.70it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Top One Error of the network on validation images: 56 %


100%|██████████| 5/5 [00:04<00:00,  1.06it/s]
100%|██████████| 5/5 [00:04<00:00,  1.10it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

Top One Error of the network on train images: 53 %


100%|██████████| 2/2 [00:01<00:00,  1.77it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Top One Error of the network on validation images: 66 %


100%|██████████| 5/5 [00:04<00:00,  1.06it/s]
100%|██████████| 5/5 [00:04<00:00,  1.10it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

Top One Error of the network on train images: 57 %


100%|██████████| 2/2 [00:01<00:00,  1.78it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Top One Error of the network on validation images: 65 %


100%|██████████| 5/5 [00:04<00:00,  1.07it/s]
100%|██████████| 5/5 [00:04<00:00,  1.09it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

Top One Error of the network on train images: 43 %


100%|██████████| 2/2 [00:01<00:00,  1.65it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Top One Error of the network on validation images: 50 %


100%|██████████| 5/5 [00:04<00:00,  1.02it/s]
100%|██████████| 5/5 [00:04<00:00,  1.02it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

Top One Error of the network on train images: 37 %


100%|██████████| 2/2 [00:01<00:00,  1.72it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Top One Error of the network on validation images: 47 %


100%|██████████| 5/5 [00:04<00:00,  1.08it/s]
100%|██████████| 5/5 [00:04<00:00,  1.12it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

Top One Error of the network on train images: 32 %


100%|██████████| 2/2 [00:01<00:00,  1.74it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Top One Error of the network on validation images: 46 %


100%|██████████| 5/5 [00:04<00:00,  1.08it/s]
100%|██████████| 5/5 [00:04<00:00,  1.12it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

Top One Error of the network on train images: 24 %


100%|██████████| 2/2 [00:01<00:00,  1.70it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Top One Error of the network on validation images: 40 %


100%|██████████| 5/5 [00:04<00:00,  1.09it/s]
100%|██████████| 5/5 [00:04<00:00,  1.12it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

Top One Error of the network on train images: 44 %


100%|██████████| 2/2 [00:01<00:00,  1.78it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Top One Error of the network on validation images: 52 %


100%|██████████| 5/5 [00:04<00:00,  1.08it/s]
100%|██████████| 5/5 [00:04<00:00,  1.11it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

Top One Error of the network on train images: 34 %


100%|██████████| 2/2 [00:01<00:00,  1.71it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Top One Error of the network on validation images: 48 %


100%|██████████| 5/5 [00:04<00:00,  1.06it/s]
100%|██████████| 5/5 [00:04<00:00,  1.13it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

Top One Error of the network on train images: 31 %


100%|██████████| 2/2 [00:01<00:00,  1.81it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Top One Error of the network on validation images: 47 %


100%|██████████| 5/5 [00:04<00:00,  1.07it/s]
100%|██████████| 5/5 [00:04<00:00,  1.13it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

Top One Error of the network on train images: 16 %


100%|██████████| 2/2 [00:01<00:00,  1.77it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Top One Error of the network on validation images: 38 %


100%|██████████| 5/5 [00:04<00:00,  1.09it/s]
100%|██████████| 5/5 [00:04<00:00,  1.12it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

Top One Error of the network on train images: 19 %


100%|██████████| 2/2 [00:01<00:00,  1.73it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Top One Error of the network on validation images: 36 %


100%|██████████| 5/5 [00:04<00:00,  1.05it/s]
100%|██████████| 5/5 [00:04<00:00,  1.12it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

Top One Error of the network on train images: 14 %


100%|██████████| 2/2 [00:01<00:00,  1.67it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Top One Error of the network on validation images: 42 %


100%|██████████| 5/5 [00:04<00:00,  1.07it/s]
100%|██████████| 5/5 [00:04<00:00,  1.12it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

Top One Error of the network on train images: 10 %


100%|██████████| 2/2 [00:01<00:00,  1.72it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Top One Error of the network on validation images: 33 %


100%|██████████| 5/5 [00:04<00:00,  1.09it/s]
100%|██████████| 5/5 [00:04<00:00,  1.09it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

Top One Error of the network on train images: 10 %


100%|██████████| 2/2 [00:01<00:00,  1.74it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Top One Error of the network on validation images: 36 %


100%|██████████| 5/5 [00:04<00:00,  1.07it/s]
100%|██████████| 5/5 [00:04<00:00,  1.11it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

Top One Error of the network on train images: 5 %


100%|██████████| 2/2 [00:01<00:00,  1.73it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Top One Error of the network on validation images: 31 %


100%|██████████| 5/5 [00:04<00:00,  1.08it/s]
100%|██████████| 5/5 [00:04<00:00,  1.12it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

Top One Error of the network on train images: 6 %


100%|██████████| 2/2 [00:01<00:00,  1.67it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Top One Error of the network on validation images: 30 %


100%|██████████| 5/5 [00:04<00:00,  1.08it/s]
100%|██████████| 5/5 [00:04<00:00,  1.12it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

Top One Error of the network on train images: 1 %


100%|██████████| 2/2 [00:01<00:00,  1.69it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Top One Error of the network on validation images: 31 %


100%|██████████| 5/5 [00:04<00:00,  1.08it/s]
100%|██████████| 5/5 [00:04<00:00,  1.11it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

Top One Error of the network on train images: 0 %


100%|██████████| 2/2 [00:01<00:00,  1.70it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Top One Error of the network on validation images: 33 %


100%|██████████| 5/5 [00:04<00:00,  1.07it/s]
100%|██████████| 5/5 [00:04<00:00,  1.12it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

Top One Error of the network on train images: 0 %


100%|██████████| 2/2 [00:01<00:00,  1.68it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Top One Error of the network on validation images: 28 %


100%|██████████| 5/5 [00:04<00:00,  1.03it/s]
100%|██████████| 5/5 [00:04<00:00,  1.07it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

Top One Error of the network on train images: 1 %


100%|██████████| 2/2 [00:01<00:00,  1.73it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Top One Error of the network on validation images: 33 %


100%|██████████| 5/5 [00:04<00:00,  1.06it/s]
100%|██████████| 5/5 [00:04<00:00,  1.11it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

Top One Error of the network on train images: 0 %


100%|██████████| 2/2 [00:01<00:00,  1.73it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Top One Error of the network on validation images: 31 %


100%|██████████| 5/5 [00:04<00:00,  1.09it/s]
100%|██████████| 5/5 [00:04<00:00,  1.11it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

Top One Error of the network on train images: 0 %


100%|██████████| 2/2 [00:01<00:00,  1.74it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Top One Error of the network on validation images: 30 %


100%|██████████| 5/5 [00:04<00:00,  1.09it/s]
100%|██████████| 5/5 [00:04<00:00,  1.12it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

Top One Error of the network on train images: 0 %


100%|██████████| 2/2 [00:01<00:00,  1.75it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Top One Error of the network on validation images: 30 %


100%|██████████| 5/5 [00:04<00:00,  1.07it/s]
100%|██████████| 5/5 [00:04<00:00,  1.10it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

Top One Error of the network on train images: 7 %


100%|██████████| 2/2 [00:01<00:00,  1.76it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Top One Error of the network on validation images: 38 %


100%|██████████| 5/5 [00:04<00:00,  1.10it/s]
100%|██████████| 5/5 [00:04<00:00,  1.13it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

Top One Error of the network on train images: 6 %


100%|██████████| 2/2 [00:01<00:00,  1.74it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Top One Error of the network on validation images: 32 %


100%|██████████| 5/5 [00:04<00:00,  1.09it/s]
100%|██████████| 5/5 [00:04<00:00,  1.11it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

Top One Error of the network on train images: 5 %


100%|██████████| 2/2 [00:01<00:00,  1.78it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Top One Error of the network on validation images: 31 %


100%|██████████| 5/5 [00:04<00:00,  1.09it/s]
100%|██████████| 5/5 [00:04<00:00,  1.11it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

Top One Error of the network on train images: 2 %


100%|██████████| 2/2 [00:01<00:00,  1.79it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Top One Error of the network on validation images: 27 %


100%|██████████| 5/5 [00:04<00:00,  1.08it/s]
100%|██████████| 5/5 [00:04<00:00,  1.11it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

Top One Error of the network on train images: 1 %


100%|██████████| 2/2 [00:01<00:00,  1.72it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Top One Error of the network on validation images: 31 %


100%|██████████| 5/5 [00:04<00:00,  1.09it/s]
100%|██████████| 5/5 [00:04<00:00,  1.12it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

Top One Error of the network on train images: 0 %


100%|██████████| 2/2 [00:01<00:00,  1.80it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Top One Error of the network on validation images: 33 %


100%|██████████| 5/5 [00:04<00:00,  1.07it/s]
100%|██████████| 5/5 [00:04<00:00,  1.10it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

Top One Error of the network on train images: 0 %


100%|██████████| 2/2 [00:01<00:00,  1.77it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Top One Error of the network on validation images: 31 %


100%|██████████| 5/5 [00:04<00:00,  1.09it/s]
100%|██████████| 5/5 [00:04<00:00,  1.13it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

Top One Error of the network on train images: 0 %


100%|██████████| 2/2 [00:01<00:00,  1.73it/s]

Top One Error of the network on validation images: 30 %



