In [None]:
import torch
import torchvision
from torch.utils.data import Dataset
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as dsets
import torch.nn.functional as F
import matplotlib.pylab as plt
import numpy as np
from torchvision import models
# Reproducibility and device selection: seed torch's RNG (also drives
# DataLoader shuffling) and pick the GPU when one is visible.
torch.manual_seed(2)
train_on_gpu = torch.cuda.is_available()
device = torch.device('cuda' if train_on_gpu else 'cpu')
print('CUDA is available!' if train_on_gpu else 'CUDA is not available')

In [None]:
from google.colab import drive, files

# Dataset roots on Google Drive, each laid out for torchvision's ImageFolder
# (one sub-directory per class).
trainpath = "/content/drive/My Drive/dataset/"
validationpath = "/content/drive/My Drive/datasetT/"

# Mount Drive so the paths above resolve inside the Colab runtime.
drive.mount('/content/drive/')

In [None]:
# Hyperparameters. Defined before the loaders so batch_size is actually used by them.
num_epochs = 30
num_classes = 3
batch_size = 32
learning_rate = 0.0001


def _make_image_transform():
    """Return the resize-to-224x224 + ToTensor + channel-normalize pipeline.

    The mean/std are the standard ImageNet channel statistics, which the
    pretrained SqueezeNet weights loaded later expect. Resize((224, 224))
    already yields exactly 224x224, so the former CenterCrop(224) was a no-op
    and is omitted.
    """
    return transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])


# Training transform and data loader (shuffled every epoch).
transform1 = _make_image_transform()
train_dataset = dsets.ImageFolder(root=trainpath, transform=transform1)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

# Validation transform and data loader (fixed order).
validation_transform = _make_image_transform()
test_dataset = dsets.ImageFolder(root=validationpath, transform=validation_transform)
validation_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

# ImageFolder derives label indices from each root's sub-folder names
# independently, so both splits must expose identical class folders or the
# validation labels silently disagree with the training labels.
assert train_dataset.class_to_idx == test_dataset.class_to_idx, (
    'train/validation class folders differ: {} vs {}'.format(
        train_dataset.class_to_idx, test_dataset.class_to_idx))
assert len(train_dataset.classes) == num_classes, (
    'num_classes={} but the training folder has {} classes: {}'.format(
        num_classes, len(train_dataset.classes), train_dataset.classes))

print('there are {}  images in training set.'.format(len(train_dataset)))
print('there are {}  images in test set.'.format(len(test_dataset)))
# len(DataLoader) is the number of batches, not images.
print('there are {} batches in training loader.'.format(len(train_loader)))
print('there are {} batches in validation loader.'.format(len(validation_loader)))

In [None]:
import torch

# Pretrained SqueezeNet 1.0 fetched via torch.hub from the torchvision repo,
# pinned to the v0.6.0 tag. 'squeezenet1_1' from the same repo is an
# alternative architecture that can be swapped in here.
hub_repo = 'pytorch/vision:v0.6.0'
model = torch.hub.load(hub_repo, 'squeezenet1_0', pretrained=True)
model.eval()

In [None]:
# Swap the final 1x1 conv of the classifier for a fresh one emitting
# `num_classes` channels (one score per target class). The input width is read
# from the layer being replaced instead of hard-coding 512, so this also works
# if a different SqueezeNet variant is loaded.
old_head = model.classifier[1]
model.classifier[1] = nn.Conv2d(old_head.in_channels, num_classes, kernel_size=(1, 1), stride=(1, 1))
# torchvision's SqueezeNet records its class count in `num_classes`; keep it in
# sync with the new head so the attribute does not report the stale value.
model.num_classes = num_classes

In [None]:
# Inference mode (equivalent to model.eval()); returns the module so the cell
# displays the modified architecture.
model.train(False)

In [None]:
from torchsummary import summary

# Layer-by-layer summary for a single 3x224x224 input. Use the notebook's
# selected `device` rather than forcing .cuda(), so this cell also runs on
# CPU-only runtimes; torchsummary accepts device="cuda" or "cpu".
model = model.to(device)
summary(model, (3, 224, 224), device=device.type)

In [None]:
# Move the model to the selected device before building the optimizer.
model = model.to(device)
# Loss and optimizer
criterion = nn.CrossEntropyLoss().to(device)
# SGD with momentum, driven by the learning_rate hyperparameter instead of a
# duplicated literal, so tuning it in one place takes effect.
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)

In [None]:
# Train the model: fine-tune all parameters for num_epochs, recording per-batch
# loss and accuracy, then save the weights to Drive.
total_step = len(train_loader)
log_interval = 10  # print a progress line every `log_interval` batches
loss_list = []  # per-batch training loss, in iteration order
acc_list = []   # per-batch training accuracy (fraction correct), in iteration order
for epoch in range(num_epochs):
    # Earlier cells leave the model in eval mode; switch to training mode so
    # layers that behave differently while training (e.g. dropout) are active.
    model.train()
    epoch_correct = 0
    epoch_total = 0
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        # Run the forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backprop and perform an SGD optimisation step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Track the loss and accuracy of this batch
        batch_loss = loss.item()
        predicted = outputs.argmax(dim=1)
        correct = (predicted == labels).sum().item()
        total = labels.size(0)
        loss_list.append(batch_loss)
        acc_list.append(correct / total)
        epoch_correct += correct
        epoch_total += total

        if (i + 1) % log_interval == 0 or (i + 1) == total_step:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Accuracy: {:.4f}%'.format(
                epoch + 1, num_epochs, i + 1, total_step, batch_loss, (correct / total) * 100))

    # Whole-epoch accuracy is more informative than the last batch alone.
    print('Epoch [{}/{}] training accuracy: {:.4f}%'.format(
        epoch + 1, num_epochs, epoch_correct / max(epoch_total, 1) * 100))

PATH = '/content/drive/My Drive/--.pth'
torch.save(model.state_dict(), PATH)
                    

In [None]:
# Validation pass: inference mode with gradient tracking disabled, counting
# top-1 hits across every batch of the validation loader.
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in validation_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)  # already on `device`, no extra .to() needed
        predicted = outputs.argmax(dim=1)
        correct += int((predicted == labels).sum())
        total += labels.numel()

print('Test Accuracy of the model on the  test images: {} %'.format((correct / total) * 100))