In [1]:
import csv
import torchvision
import torch.nn as nn
import torch
from torchvision import transforms,models,datasets
from torch import optim
import os
from collections import OrderedDict
import sys

In [2]:
!mkdir reduced_data
import zipfile
with zipfile.ZipFile("/content/reduced_data.zip", 'r') as zip_ref:
    zip_ref.extractall("/content/reduced_data")

In [10]:
batch_size = 1024

train_transform = transforms.Compose([transforms.Resize(255),
                                transforms.CenterCrop(224),
                                transforms.ToTensor()])

train_dataset = torchvision.datasets.ImageFolder("/content/reduced_data", transform= train_transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size ,shuffle=True)

test_transform = transforms.Compose([transforms.Resize(255),
                                transforms.CenterCrop(224),
                                transforms.ToTensor()])

test_dataset = torchvision.datasets.ImageFolder("/content/reduced_data", transform= test_transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size ,shuffle=True)


In [11]:
model = models.densenet121(pretrained = True)
for params in model.parameters():
    params.requires_grad = False


classifier = nn.Sequential(OrderedDict([
    ('fc1',nn.Linear(1024,500)),
    ('relu',nn.ReLU()),
    ('fc2',nn.Linear(500,2)),
    ('Output',nn.LogSoftmax(dim=1))
]))

model.classifier = classifier
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)

In [12]:
optimizer = optim.Adam(model.classifier.parameters())
criterian = nn.NLLLoss()
list_train_loss = []
list_test_loss = []
f = open('metrics.csv', 'w')
file = csv.writer(f)
file.writerow(['Epoch', 'Train loss', 'Test loss', 'Accuracy'])
epochs = 10
for epoch in range(epochs):
    train_loss = 0
    test_loss = 0
    for bat, (img, label) in enumerate(train_loader):
        # moving batch and lables to gpu
        img = img.to(device)
        label = label.to(device)

        model.train()
        optimizer.zero_grad()

        output = model(img)
        loss = criterian(output, label)
        loss.backward()
        optimizer.step()
        train_loss = train_loss + loss.item()
        # print(bat)

    accuracy = 0
    for bat, (img, label) in enumerate(test_loader):
        img = img.to(device)
        label = label.to(device)

        model.eval()
        logps = model(img)
        loss = criterian(logps, label)

        test_loss += loss.item()
        ps = torch.exp(logps)
        top_ps, top_class = ps.topk(1, dim=1)
        equality = top_class == label.view(*top_class.shape)
        accuracy += torch.mean(equality.type(torch.FloatTensor)).item()

    list_train_loss.append(train_loss / 20)
    list_test_loss.append(test_loss / 20)
    print('epoch: ', epoch, '    train_loss:  ', train_loss / 20, '   test_loss:    ', test_loss / 20,
          '    accuracy:  ', accuracy / len(test_loader))
    file.writerow([epoch,train_loss / 20, test_loss / 20,accuracy / len(test_loader)])


  " Skipping tag %s" % (size, len(data), tag)
  " Skipping tag %s" % (size, len(data), tag)
  " Skipping tag %s" % (size, len(data), tag)
  " Skipping tag %s" % (size, len(data), tag)
  " Skipping tag %s" % (size, len(data), tag)
  " Skipping tag %s" % (size, len(data), tag)
  " Skipping tag %s" % (size, len(data), tag)


epoch:  0     train_loss:   0.05865321457386017    test_loss:     0.0388381689786911     accuracy:   0.823913961648941
epoch:  1     train_loss:   0.028877216577529907    test_loss:     0.015733447670936585     accuracy:   0.974007785320282
epoch:  2     train_loss:   0.015653782337903977    test_loss:     0.011155467480421066     accuracy:   0.9686868488788605
epoch:  3     train_loss:   0.009800592064857483    test_loss:     0.012892067059874534     accuracy:   0.9468384981155396
epoch:  4     train_loss:   0.007426820322871208    test_loss:     0.007265613973140716     accuracy:   0.9763489365577698
epoch:  5     train_loss:   0.006167935580015183    test_loss:     0.00622833389788866     accuracy:   0.979704737663269
epoch:  6     train_loss:   0.005749544501304627    test_loss:     0.00649643074721098     accuracy:   0.9785276651382446
epoch:  7     train_loss:   0.004900055937469006    test_loss:     0.006786392070353031     accuracy:   0.9733440577983856
epoch:  8     train_loss

In [13]:
# torch.save(model.state_dict(), 'model.pth')
torch.save(model, "model_pytorch.h5")
f.close()