Refer https://www.learnpytorch.io for further concepts on PyTorch

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torch.nn.functional as F
import torchvision

from torchvision import transforms

from PIL import Image, ImageFile

ImageFile.LOAD_TRUNCATED_IMAGES=True

In [2]:
print(torch.cuda.is_available())

True


In [3]:
#!python3 download.py

Downloading 1393 images
Error downloading http://www.off-n-trottin.com/images/Breyers%20and%20Kittens%20035.jpg
Error downloading http://www.arar93.dsl.pipex.com/mds975/Images/c_ronald_ginger_cat_01.jpg
Error downloading http://www.unknownhighway.com/images/uploads/littletigercat-12-20-07-small.jpg
Error downloading http://www.whitelightening.net/BuzzellTest/Creative/Tripp-TigerCat.jpg
Error downloading http://www1.istockphoto.com/file_thumbview_approve/2754709/2/istockphoto_2754709_white_tiger.jpg
Error downloading http://www.thebassethound.com/images/king1-sm.jpg
Error downloading http://www.salmonherder.com/silver031.jpg
Error downloading http://www.kenairiverhideaway.com/pix1/pat3.jpg
Error downloading http://photos.oregonlive.com/photogallery/f43aa33cdff2bbd5e0173ab7a9460f04.jpg
Error downloading http://image59.webshots.com/459/3/97/24/2362397240073428963uOVsAS_ph.jpg
Error downloading http://www.atmos.washington.edu/~mantua/images/silver2.gif
Error downloading http://www.alaskafi

In [13]:
def check_image(path):
    try:
        im = Image.open(path)
        return True
    except:
        return False

### Image Transforms (Resize, Tensor Convertion, ImageNet mean, std)



In [14]:
img_transforms = transforms.Compose([
    transforms.Resize((64,64)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) #Avoid exploding gradient problem
])

In [15]:
train_data_path = "./train/"
train_data = torchvision.datasets.ImageFolder(root=train_data_path, transform=img_transforms, is_valid_file=check_image)

In [16]:
val_data_path = "./val/"
val_data = torchvision.datasets.ImageFolder(root=val_data_path,transform=img_transforms, is_valid_file=check_image)

In [17]:
test_data_path = "./test/"
test_data = torchvision.datasets.ImageFolder(root=train_data_path, transform=img_transforms, is_valid_file=check_image)

In [18]:
batch_size = 64

In [19]:
train_data_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size)
val_data_loader = torch.utils.data.DataLoader(val_data, batch_size=batch_size)
test_data_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size)

## Neural Network

In [74]:
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet,self).__init__()
        self.fc1 = nn.Linear(12288, 84)
        self.fc2 = nn.Linear(84, 50)
        self.fc3 = nn.Linear(50, 2)
        
    def forward(self, x):
        x = x.view(-1, 12288) # Convert to 1D Vector
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.softmax(self.fc3(x))
        return x
    
simplenet = SimpleNet()

In [75]:
optimizer = optim.Adam(simplenet.parameters(), lr=0.001)

## Training

In [76]:
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
    
simplenet.to(device)

SimpleNet(
  (fc1): Linear(in_features=12288, out_features=84, bias=True)
  (fc2): Linear(in_features=84, out_features=50, bias=True)
  (fc3): Linear(in_features=50, out_features=2, bias=True)
)

In [87]:
def train(model, optimizer, loss_fn, train_loader, val_loader, epochs=20, device="cuda"):
    for epoch in range(epochs):
        training_loss = 0.0
        valid_loss = 0.0
        model.train()
        
        for batch in train_loader:
            optimizer.zero_grad()
            inputs, targets = batch
            inputs = inputs.to(device)
            targets = targets.to(device)
            output = model(inputs)
            loss = loss_fn(output, targets)
            loss.backward()
            optimizer.step()
            training_loss += loss.data.item()*inputs.size(0)
        training_loss /= len(train_loader.dataset)
    
        model.eval()
        num_correct = 0
        num_examples = 0
    
        for batch in val_loader:
            inputs, targets = batch
            inputs = inputs.to(device)
            output = model(inputs)
            targets = targets.to(device)
            loss = loss_fn(output,targets)
            valid_loss += loss.data.item()
            correct = torch.eq(torch.max(F.softmax(output), dim=1)[1],targets).view(-1)
            num_correct += torch.sum(correct).item()
            num_examples += correct.shape[0]
        
        valid_loss /= len(val_loader.dataset)
    
        print('Epoch: {}, Training Loss: {:.2f}, Validation Loss: {:.2f}, accuracy = {:.2f}'.format(epoch, training_loss, valid_loss, num_correct / num_examples))

In [102]:
train(simplenet, optimizer, torch.nn.CrossEntropyLoss(), train_data_loader, test_data_loader, epochs=20, device=device)

  x = F.softmax(self.fc3(x))
  correct = torch.eq(torch.max(F.softmax(output), dim=1)[1],targets).view(-1)


Epoch: 0, Training Loss: 0.41, Validation Loss: 0.01, accuracy = 0.91
Epoch: 1, Training Loss: 0.41, Validation Loss: 0.01, accuracy = 0.91
Epoch: 2, Training Loss: 0.40, Validation Loss: 0.01, accuracy = 0.91
Epoch: 3, Training Loss: 0.40, Validation Loss: 0.01, accuracy = 0.91
Epoch: 4, Training Loss: 0.40, Validation Loss: 0.01, accuracy = 0.91
Epoch: 5, Training Loss: 0.40, Validation Loss: 0.01, accuracy = 0.91
Epoch: 6, Training Loss: 0.40, Validation Loss: 0.01, accuracy = 0.91
Epoch: 7, Training Loss: 0.40, Validation Loss: 0.01, accuracy = 0.92
Epoch: 8, Training Loss: 0.40, Validation Loss: 0.01, accuracy = 0.92
Epoch: 9, Training Loss: 0.40, Validation Loss: 0.01, accuracy = 0.92
Epoch: 10, Training Loss: 0.40, Validation Loss: 0.01, accuracy = 0.92
Epoch: 11, Training Loss: 0.40, Validation Loss: 0.01, accuracy = 0.92
Epoch: 12, Training Loss: 0.40, Validation Loss: 0.01, accuracy = 0.92
Epoch: 13, Training Loss: 0.40, Validation Loss: 0.01, accuracy = 0.92
Epoch: 14, Train

## Training 

In [106]:
from PIL import Image

labels = ['cat', 'fish']

FILENAME = "./val/fish/silver24.jpg"

img = Image.open(FILENAME)
img = img_transforms(img).to(device)
img = torch.unsqueeze(img, 0)

simplenet.eval()
prediction = F.softmax(simplenet(img), dim=1)
prediction = prediction.argmax()
print(labels[prediction])


fish


  x = F.softmax(self.fc3(x))


In [107]:
torch.save(simplenet, "./model/simplenet1")

In [109]:
simplenet = torch.load("./model/simplenet1")    

In [110]:
torch.save(simplenet.state_dict(), "./model/simplenet2")    
simplenet = SimpleNet()
simplenet_state_dict = torch.load("./model/simplenet2")
simplenet.load_state_dict(simplenet_state_dict)   

<All keys matched successfully>