In [None]:
# Source: https://arxiv.org/pdf/1801.09573.pdf
# Dataset description, quoted verbatim from the paper ("Datasets" section):
# We have collected the 1000 images from the internet in 2 categories as muffin and Chihuahua
# including 200 images from Oxford pet animal dataset [7] as shown in Figure 3. Please note here, all
# resources as images from the internet is for research purpose only, we don’t own any of them. ImageNet
# [5] also includes 1750 Chihuahua and 1335 various type of muffin images already.

# http://image-net.org/synset?wnid=n02085620 : 1750 Chihuahua
# http://image-net.org/synset?wnid=n07690273 : 1335 muffins

In [None]:
# https://aviaryan.com/blog/gsoc/downloading-files-from-urls
# import requests
# def is_downloadable(url):
#     """
#     Does the url contain a downloadable resource
#     """
#     try:
#         h = requests.head(url, allow_redirects=True)
#     except:
#         return False
#     header = h.headers
#     content_type = header.get('content-type')
#     if content_type == None:
#         return False
#     if 'text' in content_type.lower():
#         return False
#     if 'html' in content_type.lower():
#         return False
#     return True

# file1 = open('chihuahua.txt', 'r')
# Lines = file1.readlines()
# count = 0
# for line in Lines:
#     print(line)
#     if(is_downloadable(line)):
#         count += 1
#         r = requests.get(line, allow_redirects=True)
#         open(f'data/chihuahua/{count}.jpg', 'wb').write(r.content)

In [None]:
# file1 = open('muffin.txt', 'r')
# Lines = file1.readlines()
# count = 0
# for line in Lines:
#     print(line)
#     if(is_downloadable(line)):
#         count += 1
#         r = requests.get(line, allow_redirects=True)
#         open(f'data/muffin/{count}.jpg', 'wb').write(r.content)

In [1]:
import torch
import torchvision
from torchvision import datasets, models, transforms
import torch.nn as nn
import torch.optim as optim
import time
import os
from copy import copy
from copy import deepcopy
import torch.nn.functional as F

# Image preprocessing.
# - Training images get random-crop + horizontal-flip augmentation.
# - Validation/test images are only resized and center-cropped (no augmentation),
#   so evaluation is deterministic.
# Both pipelines finish with ToTensor() followed by Normalize(mean, std).
#
# NOTE(review): these mean/std values match the commonly quoted CIFAR-10 channel
# statistics. They were presumably chosen to match the normalization used for the
# pretrained checkpoint loaded later (confirm).
# Previously tried alternative: mean, std = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
mean = (0.4914, 0.4822, 0.4465)
std = (0.2023, 0.1994, 0.2010)

# Deterministic pipeline for validation / test images.
preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]
)

# Augmenting pipeline for training images.
preprocess_augment = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]
)


# Train/validation split configuration.
SEED = 0        # fixes the random train/val split so re-runs use the same partition
N_TRAIN = 1600  # images used for training; all remaining images go to validation

full_train_dataset = torchvision.datasets.ImageFolder('data/imagenet')
print(full_train_dataset)

# Derive the validation size from the actual folder contents instead of hardcoding
# the total (random_split raises if the lengths don't sum to len(dataset)).
n_total = len(full_train_dataset)
assert 0 < N_TRAIN < n_total, f"N_TRAIN={N_TRAIN} must be between 1 and {n_total - 1}"
train_dataset, val_dataset = torch.utils.data.random_split(
    full_train_dataset,
    [N_TRAIN, n_total - N_TRAIN],
    generator=torch.Generator().manual_seed(SEED),
)

# Both Subsets initially wrap the same dataset object. Give each its own shallow copy
# so the two splits can carry different transforms without mutating full_train_dataset.
train_dataset.dataset = copy(full_train_dataset)
train_dataset.dataset.transform = preprocess_augment
val_dataset.dataset = copy(full_train_dataset)
val_dataset.dataset.transform = preprocess

# DataLoaders for the train and validation splits (the test loader is built later).
BATCH_SIZE = 4
NUM_WORKERS = 2
train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS
)
val_dataloader = torch.utils.data.DataLoader(
    val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS
)



Dataset ImageFolder
    Number of datapoints: 1988
    Root location: data/train


In [2]:
import numpy as np

# Class names in label-index order. ImageFolder assigns label indices by sorting the
# class sub-directory names.
# NOTE(review): assumes the data folders sort as chihuahua (0), muffin (1); confirm.
classes = np.array(['chihuahua', 'muffin'])


def unnormalize(images, mean, std):
    """Invert ``transforms.Normalize(mean, std)`` on a tensor whose last three dims are (C, H, W).

    Works for a single image (C, H, W) or a batch (N, C, H, W). The earlier preview
    code used ``images / 2 + 0.5``, which only inverts Normalize((0.5,)*3, (0.5,)*3)
    and shows wrong colours for the mean/std used in this notebook.
    """
    mean_t = torch.as_tensor(mean, dtype=images.dtype, device=images.device).view(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=images.dtype, device=images.device).view(-1, 1, 1)
    return images * std_t + mean_t


def show_batch(dataloader, class_names, mean, std):
    """Plot one batch from ``dataloader`` as an image grid; return the batch's class names.

    Optional sanity check of the input pipeline. Example:
    ``show_batch(train_dataloader, classes, mean, std)``.
    """
    import matplotlib.pyplot as plt  # local import: only needed for this optional preview

    # DataLoader iterators no longer expose .next(); use the builtin next().
    images, labels = next(iter(dataloader))
    grid = torchvision.utils.make_grid(unnormalize(images, mean, std)).clamp(0, 1)
    names = class_names[labels.numpy()]
    fig, ax = plt.subplots()
    ax.imshow(np.transpose(grid.numpy(), (1, 2, 0)))
    ax.set_title(', '.join(names))
    ax.axis('off')
    plt.show()
    return names

In [None]:
dataloaders = {'train': train_dataloader, 'val': val_dataloader}

from myNetwork.myResNet import ResNet
from trainer import trainer


def SEResNet18(num_classes=10):
    """Build an SE-ResNet-18 from the project's ResNet.

    Uses the SE basic block with [2, 2, 2, 2] blocks per stage. The default of
    10 classes matches the head of the pretrained checkpoint loaded below.
    """
    return ResNet(ResNet._BLOCK_SEBASIC, [2, 2, 2, 2], num_classes)


# Prefer the second GPU (the original configuration). Fall back to cuda:0 or the CPU
# instead of failing with an invalid device ordinal on machines with fewer GPUs.
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

PRETRAINED_WEIGHTS = 'result-20210129-104338/seresnet18_adam_0.01.pth'
model = SEResNet18()
# The model is still on the CPU here. map_location lets a GPU-saved checkpoint
# load even when CUDA is unavailable.
model.load_state_dict(torch.load(PRETRAINED_WEIGHTS, map_location='cpu'))
# Swap the pretrained head for a fresh one sized to this task's classes.
# NOTE(review): assumes classifier[2] is the final Linear layer with 512 input features; confirm.
model.classifier[2] = nn.Linear(512, len(classes))
model.eval()


In [None]:
# Fine-tuning hyperparameters. LEARNING_RATE is also embedded in the weights filename
# so the saved file name cannot disagree with the run that produced it.
LEARNING_RATE = 0.01
PLATEAU_PATIENCE = 5
NUM_EPOCHS = 60

model.to(device)
criterion = nn.CrossEntropyLoss()
# Fine-tune every parameter (backbone and the new head) with Adam.
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
# Lower the LR once the monitored metric stops decreasing for PLATEAU_PATIENCE epochs.
# NOTE(review): which metric trainer feeds to scheduler.step() is not visible here.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=PLATEAU_PATIENCE)
t = trainer(device, criterion, optimizer, scheduler)
model = t.train(
    model,
    dataloaders,
    num_epochs=NUM_EPOCHS,
    weights_name=f'seresnet18_chihuahua_muffin_adam_{LEARNING_RATE}',
)

In [3]:
# Held-out test images: no augmentation, same deterministic preprocessing as validation.
test_dataset = datasets.ImageFolder(root='data/test', transform=preprocess)
print(test_dataset)
test_dataloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
)

Dataset ImageFolder
    Number of datapoints: 16
    Root location: data/test
    StandardTransform
Transform: Compose(
               Resize(size=256, interpolation=PIL.Image.BILINEAR)
               CenterCrop(size=(224, 224))
               ToTensor()
               Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.201))
           )


In [5]:
# Self-contained evaluation cell: it can run right after the data cells, without
# executing the training cells, so it rebuilds the model and device here.
from myNetwork.myResNet import ResNet
from trainer import trainer


def SEResNet18(num_classes=10):
    """Build an SE-ResNet-18 from the project's ResNet.

    Uses the SE basic block with [2, 2, 2, 2] blocks per stage.
    """
    return ResNet(ResNet._BLOCK_SEBASIC, [2, 2, 2, 2], num_classes)


# Prefer the second GPU (the original configuration). Fall back to cuda:0 or the CPU
# instead of failing with an invalid device ordinal on machines with fewer GPUs.
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

FINETUNED_WEIGHTS = 'result-20210130-150221/seresnet18_chihuahua_muffin_adam_0.01.pth'
model = SEResNet18()
# The fine-tuned checkpoint already contains the task-sized head, so install it
# *before* loading the weights.
# NOTE(review): assumes classifier[2] has 512 input features, as in the training cell.
model.classifier[2] = nn.Linear(512, len(classes))
# map_location lets a GPU-saved checkpoint load on a CPU-only machine.
model.load_state_dict(torch.load(FINETUNED_WEIGHTS, map_location='cpu'))
model.eval()
model.to(device)
criterion = nn.CrossEntropyLoss()
# trainer() takes an optimizer and a scheduler positionally.
# NOTE(review): presumably t.test() does not step them; confirm against trainer.py.
optimizer = optim.Adam(model.parameters(), lr=0.01)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=5)
t = trainer(device, criterion, optimizer, scheduler)
t.test(model, test_dataloader, classes)

cuda:1
===== Testing =====
Accuracy of the network on the 16 test images: 87 %
Accuracy of chihuahua : 100 %
Accuracy of muffin : 75 %


In [None]:
print(classes.shape[0])  # number of target classes (classes is a 1-D array of names)