In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.models as models
import torch.utils.data 
from torchvision import transforms
from torch.utils.data import DataLoader
from PIL import Image, ImageFile

## DataLoader

In [2]:
# Ask PIL to tolerate truncated image files when loading instead of failing on them.
ImageFile.LOAD_TRUNCATED_IMAGES = True

def check_image(path):
    """Return True if PIL can open ``path`` as an image, False otherwise.

    Intended as an ``is_valid_file`` callback so that unreadable files are
    skipped when a dataset is scanned. Only the open step is checked.

    Args:
        path: Filesystem path of the candidate image file.

    Returns:
        bool: True when ``Image.open`` succeeds, False when it raises.
    """
    try:
        # The context manager closes the underlying file handle immediately;
        # without it one open handle per scanned file lingers until GC.
        with Image.open(path):
            return True
    except Exception:
        # Deliberately broad (bad files can raise several unrelated types), but
        # unlike a bare ``except`` it lets KeyboardInterrupt/SystemExit through.
        return False

In [3]:
# Per-channel normalisation statistics (these match the commonly used
# ImageNet values — confirm they suit the pretrained weights in use).
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]

# Preprocessing pipeline: resize to 64x64, convert to a tensor, then
# normalise each channel with the statistics above.
img_transforms = transforms.Compose(
    [
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
    ]
)

In [4]:
def _image_folder(root):
    """Build an ImageFolder over ``root`` using the shared transforms and file filter."""
    return torchvision.datasets.ImageFolder(
        root=root,
        transform=img_transforms,
        is_valid_file=check_image,
    )


# One sub-directory per split; each split gets the same preprocessing.
train_path = "./train/"
train_data = _image_folder(train_path)

val_path = "./val/"
val_data = _image_folder(val_path)

test_path = "./test/"
test_data = _image_folder(test_path)

In [20]:
# Number of samples per mini-batch used when building the DataLoaders.
batch_size = 64

In [21]:
# Only the training split is shuffled: reshuffling each epoch helps SGD, whereas
# evaluation metrics do not depend on order, so the validation and test loaders
# keep a fixed, reproducible order.
train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(dataset=val_data, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(dataset=test_data, batch_size=batch_size, shuffle=False)

## Load Model

In [8]:
# Instantiate ResNet-50 with torchvision's default pretrained weights
# (downloaded to the local torch hub cache on first use).
ResNet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)

Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to C:\Users\wassi/.cache\torch\hub\checkpoints\resnet50-11ad3fa6.pth
100%|██████████| 97.8M/97.8M [00:10<00:00, 10.2MB/s]


In [9]:
# Echo the module tree to inspect layer names (e.g. the final `fc` layer).
ResNet

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [12]:
# Freeze every parameter whose name does not contain "bn"; parameters with
# "bn" in their name are skipped and keep their current requires_grad flag.
for param_name, tensor in ResNet.named_parameters():
    if "bn" in param_name:
        continue
    tensor.requires_grad = False

In [14]:
# Replace the classifier head with a small MLP that ends in 2 output logits.
# If this cell is re-run, `ResNet.fc` is already the Sequential built here
# (which has no `in_features`), so read the input width from its first Linear
# instead of failing with AttributeError.
# NOTE: a re-run creates freshly initialised layers; any optimizer built
# earlier must be re-created afterwards to pick them up.
_old_head = ResNet.fc
_in_features = (
    _old_head.in_features if isinstance(_old_head, nn.Linear) else _old_head[0].in_features
)
ResNet.fc = nn.Sequential(
    nn.Linear(_in_features, 500),
    nn.ReLU(),
    nn.Dropout(),  # library-default drop probability
    nn.Linear(500, 2),
)

In [16]:
# Adam over all of the model's parameters (frozen ones included) with lr = 3e-4.
optimizer = optim.Adam(ResNet.parameters(), lr=0.0003)

In [18]:
# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Move the model to the chosen device (last expression, so the module is echoed).
ResNet.to(device)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

## Training

In [15]:
def train(model, optimizer, criterion, train_loader, val_loader, num_epochs, device="cpu"):
    """Train ``model`` and report validation metrics after every epoch.

    For each epoch the model is optimised over ``train_loader``, then evaluated
    over ``val_loader`` in eval mode with gradient tracking disabled. One summary
    line (epoch index, mean training loss, mean validation loss, validation
    accuracy) is printed per epoch. The model is left in eval mode on return.

    Args:
        model: Module mapping a batch of inputs to per-class scores.
        optimizer: Optimizer already bound to the parameters to update.
        criterion: Callable ``criterion(output, targets)`` returning a scalar loss.
        train_loader: Iterable of ``(inputs, targets)`` batches exposing ``.dataset``.
        val_loader: Iterable of ``(inputs, targets)`` batches exposing ``.dataset``.
        num_epochs: Number of passes over ``train_loader``.
        device: Device the batches are moved to before the forward pass.

    Returns:
        None. Results are reported via ``print`` only.
    """
    for epoch in range(num_epochs):
        # ---- training pass ----
        model.train()
        training_loss = 0.0
        for inputs, targets in train_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            optimizer.zero_grad()
            output = model(inputs)
            loss = criterion(output, targets)
            loss.backward()
            optimizer.step()
            # Weight the batch loss by batch size so the epoch figure is a
            # per-sample mean even when the last batch is smaller.
            training_loss += loss.item() * inputs.size(0)
        num_train = len(train_loader.dataset)
        training_loss = training_loss / num_train if num_train else float("nan")

        # ---- validation pass ----
        valid_loss = 0.0
        num_correct = 0
        num_examples = 0
        model.eval()
        # No gradients are needed for evaluation; skipping autograd bookkeeping
        # saves memory and time.
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs = inputs.to(device)
                targets = targets.to(device)
                output = model(inputs)
                loss = criterion(output, targets)
                valid_loss += loss.item() * inputs.size(0)
                # Softmax is monotonic, so argmax over raw scores selects the
                # same class as argmax over probabilities.
                predicted = output.argmax(dim=1)
                num_correct += (predicted == targets).sum().item()
                num_examples += targets.shape[0]
        # An empty validation set yields NaN instead of ZeroDivisionError.
        accuracy = num_correct / num_examples if num_examples else float("nan")
        num_val = len(val_loader.dataset)
        valid_loss = valid_loss / num_val if num_val else float("nan")

        print('Epoch: {}, Training Loss: {:.2f}, Validation Loss: {:.2f}, accuracy = {:.2f}'.format(epoch, training_loss, valid_loss, accuracy))

In [22]:
# Fine-tune for 2 epochs with cross-entropy loss on the selected device.
train(ResNet, optimizer, nn.CrossEntropyLoss(), train_loader, val_loader, 2, device)

Epoch: 0, Training Loss: 0.53, Validation Loss: 0.29, accuracy = 0.90
Epoch: 1, Training Loss: 0.29, Validation Loss: 0.29, accuracy = 0.87


## Making Predictions

In [25]:
# Class names by predicted index.
# NOTE(review): assumed to match the dataset's class-index order — confirm
# against `train_data.classes`.
label = ['cat', 'fish']

# Convert to 3-channel RGB so grayscale / RGBA files still fit the 3-channel
# normalisation, and close the file handle once the pixels have been copied.
with Image.open('./test/cat/6170850_a262e64099.jpg') as raw_img:
    img = raw_img.convert('RGB')
img = img_transforms(img).to(device)
img = img.unsqueeze(0)  # add a leading batch dimension

ResNet.eval()
with torch.no_grad():  # inference only: no autograd graph needed
    prediction = F.softmax(ResNet(img), dim=1)
prediction = prediction.argmax()

print(label[prediction.item()])

cat


In [24]:
# Accuracy over the whole test set.
num_correct = 0
num_examples = 0
# Switch to eval mode explicitly rather than relying on an earlier cell having
# done so, and disable gradient tracking since nothing is being optimised.
ResNet.eval()
with torch.no_grad():
    for inputs, targets in test_loader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        output = ResNet(inputs)
        # argmax over raw scores equals argmax over softmax probabilities.
        correct = torch.eq(output.argmax(dim=1), targets)
        num_correct += torch.sum(correct).item()
        num_examples += correct.shape[0]

print('Accuracy on test data: {:.2f}'.format(num_correct / num_examples))

Accuracy on test data: 0.91
