In [1]:
#from zipfile import ZipFile
#zf = ZipFile('newFlowers.zip', 'r')
#zf.extractall()
#zf.close()

In [2]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim 
from torch.utils.data import DataLoader, WeightedRandomSampler, Dataset
import torchvision.datasets as datasets
import torchvision.transforms as transforms 
from tqdm.notebook import tqdm
import os
import numpy as np
device  = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
def get_mean_std(loader):
    # var[X] = E[X**2] - E[X]**2
    channels_sum, channels_sqrd_sum, num_batches = 0, 0, 0

    for data, _ in tqdm(loader):
        channels_sum += torch.mean(data, dim=[0, 2, 3])
        channels_sqrd_sum += torch.mean(data ** 2, dim=[0, 2, 3])
        num_batches += 1

    mean = channels_sum / num_batches
    std = (channels_sqrd_sum / num_batches - mean ** 2) ** 0.5

    return mean, std

# Function to Check accuracy on training & test to see how good our model
def check_accuracy(loader, model):
    num_correct = 0 
    num_samples = 0
    model.eval()

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device=device)
            y = y.to(device=device)
            new_y = torch.tensor([1 if (i ==0 or i==1 or i==3) else 0 for i in y.cpu().numpy()]).to(device)
            scores = model(x)
            _, predictions = scores.max(1)
            num_correct += (predictions == new_y).sum()
            num_samples += predictions.size(0)


    model.train()
    return num_correct/num_samples

In [4]:
#hyperparameter
lr = 0.001
batch_size = 128
input_size = (80,80)
num_classes = 2
input_channel = 3
epochs = 50

In [5]:
class FlowerDataset(Dataset):
    def __init__(self, root_dir, weights):
        self.root_dir = root_dir
        self.transform = transforms.Compose(
            [
            transforms.Resize((64,64)),
            transforms.ToTensor()
            ]
        )
        dataset = datasets.ImageFolder(root=root_dir, transform= self.transform)
        self.dataset = []
        self.total = len(dataset)
        self.weights_count = [0]*5
        for (data, label) in tqdm(dataset):
            if self.weights_count[label] < self.total*weights[label]/5:
                self.dataset.append((data, label))
                self.weights_count[label]+=1
        

    def __len__(self):
        return sum(self.weights_count)

    def __getitem__(self, index):
        (X,y) = self.dataset[index]
        return (X, y)

root_dir = "newFlowers"
train_dataset = FlowerDataset(root_dir,[0.41,0.41,1/4,0.41, 1])

train_set, _ = torch.utils.data.random_split(train_dataset, [int(0.8*len(train_dataset)),
                                                             len(train_dataset)-int(0.8*len(train_dataset))])


test_dataset = FlowerDataset(root_dir,[1,1,1,1,1])

_, test_set = torch.utils.data.random_split(test_dataset, [int(0.8*len(test_dataset)),
                                                           len(test_dataset)-int(0.8*len(test_dataset))])


HBox(children=(FloatProgress(value=0.0, max=1327.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1327.0), HTML(value='')))




In [6]:
class_weights = [1,1,1,1,1]
sample_weight = [0]*len(train_set)
loop = tqdm(enumerate(train_set), total= len(train_set), leave = False)

for idx, (data, label) in loop:
    sample_weight[idx] = class_weights[label]

sampler = WeightedRandomSampler(sample_weight, num_samples=len(train_set), replacement=True)
train_loader = DataLoader(train_set, batch_size=batch_size, sampler=sampler)
test_loader = DataLoader(test_set, batch_size=batch_size)

HBox(children=(FloatProgress(value=0.0, max=528.0), HTML(value='')))

In [7]:

mean, std = get_mean_std(train_loader)

HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))




In [8]:
import torchvision
class CNN(nn.Module):
    def __init__(self, input_channel, num_classes):
        super(CNN, self).__init__()
        self.fc1 = nn.Linear(5,5)
        

    def forward(self, x):
        x = self.fc1(x)
        return x


# Simple Identity class that let's input pass without changes
class Identity(nn.Module):
    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        return x

In [9]:
# Load pretrain model & modify it
model1 = torchvision.models.alexnet(pretrained=True)

model1.classifier[-1] = nn.Linear(4096, 1)
model1.to(device)

Downloading: "https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth" to /home/dam797/.cache/torch/hub/checkpoints/alexnet-owt-4df8aa71.pth


HBox(children=(FloatProgress(value=0.0, max=244418560.0), HTML(value='')))




AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
 

In [None]:
# For Weighted Loss Function change the weights to the inverse ratio of number of samples

In [10]:
criterion = nn.CrossEntropyLoss(weight=torch.tensor([1,1,1,1,1], dtype=torch.float32).to(device))
optimizer = optim.Adam(model1.parameters(), lr = lr)

In [11]:
loss_per_epoch = []
for epoch in range(1, epochs):
    loop = tqdm(enumerate(train_loader), leave= False)
    loss_sum = 0.0 
    count = 0
    for batch_idx, (data, targets) in loop:

        # Get data to cuda if possible
        data = data.to(device=device)
        targets = targets.to(device=device)

        # forward
        scores = model1(data)
        ####
        new_scores = torch.zeros(scores.shape[0], 5)

        new_targets = torch.zeros(scores.shape[0], 5)
        m= {0:1, 1:1, 3:1, 2:0, 4:0}
        for i, val in enumerate(scores):
            if m[targets[i].item()] ==1:

                new_scores[targets[i].item()] = val[0]
            else:
                new_scores[targets[i].item()] = 1-val[0]
                
        ######
        new_scores = new_scores.to(device)
        ######
        loss = criterion(new_scores, targets)

        # backward
        optimizer.zero_grad()
        loss.backward()
        loss_sum+=loss
        count+=1
        # gradient descent or adam step
        optimizer.step()
       
        loop.set_description(f"Epoch [{epoch}/{epochs}]")
        loop.set_postfix(loss=torch.rand(1).item())
#     print(loss_sum/count)
    loss_per_epoch.append(loss_sum/count)

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

In [12]:
print(f"Accuracy on training set: {check_accuracy(test_loader, model1)*100:.2f}")

Accuracy on training set: 46.22
