In [1]:
from __future__ import print_function
from __future__ import division

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
from torch.utils.data import Dataset, DataLoader
import operator

import matplotlib.pyplot as plt

from torchvision.datasets import CIFAR100
import torchvision.transforms as tt

import os

In [2]:
# initialize a pretrained model (imageNet)
data_dir = "./data/hymenoptera_data"
model_name = "resnet" #choosing alexnet since it is "relatively" easy to train
# model_name = "squeezenet" # changed to squeezeNet since it gets same acc as alex but smaller
num_classes = 100 # in cifar100

batch_size = 8

num_epochs = 15

feature_extract = False #set to false so we can finetune entire model

In [3]:
# transforms
stats = ((0.5074,0.4867,0.4411),(0.2011,0.1987,0.2025))
train_transform = tt.Compose([
    tt.RandomHorizontalFlip(),
    tt.RandomCrop(32,padding=4,padding_mode="reflect"),
    tt.ToTensor(),
    tt.Normalize(*stats)
])

test_transform = tt.Compose([
    tt.ToTensor(),
    tt.Normalize(*stats)
])


In [4]:
# load data
train_data = CIFAR100(download=True,root="./data",transform=train_transform)
test_data = CIFAR100(root="./data",train=False,transform=test_transform)

trainloader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=2)

testloader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=2)


Files already downloaded and verified


In [5]:
super_classes = [1,2,17]
sub_classes = [5,6,9,10,11,12,85,86,87]

In [6]:
from tqdm import tqdm
train_2 = []
test_2 = []
for i in tqdm(train_data):
    if i[1] in sub_classes:
        train_2.append(i)
    
for i in tqdm(test_data):
    if i[1] in sub_classes:
        test_2.append(i)

100%|██████████| 50000/50000 [05:44<00:00, 145.12it/s]
100%|██████████| 10000/10000 [00:27<00:00, 364.51it/s]


In [15]:
class CIFAR9_train(Dataset):
    def __init__(self):
        # data loading
        train_x = list(zip(*train_2))[0]
        train_y = list(zip(*train_2))[1]
        
        train_y = np.array(train_y)
        
        self.x = torch.stack(train_x)
        
        self.y = torch.from_numpy(train_y)  
        
        
        self.n_samples = train_y.shape[0]
        
    def __getitem__(self, index):
        #dataset[0]
        return self.x[index], self.y[index]

    def __len__(self):
        # len(dataset)
        return self.n_samples
        
class CIFAR9_test(Dataset):
    def __init__(self):
        # data loading
        test_x = list(zip(*test_2))[0]
        test_y = list(zip(*test_2))[1]
        
        test_y = np.array(test_y)
        
        self.x = torch.stack(test_x)
        self.y = torch.from_numpy(test_y)    
        
        self.n_samples = test_y.shape[0]
        
    def __getitem__(self, index):
        #dataset[0]
        return self.x[index], self.y[index]

    def __len__(self):
        # len(dataset)
        return self.n_samples
    

In [16]:

trainData = CIFAR9_train()
testData = CIFAR9_test()


trainloader = torch.utils.data.DataLoader(trainData, batch_size=batch_size, shuffle=True, num_workers=2)

testloader = torch.utils.data.DataLoader(testData, batch_size=batch_size, shuffle=False, num_workers=2)

In [17]:
super_classes_10 = [1,2,15,17] #15-> reptiles
sub_classes_10 = [5,6,9,10,11,12,75,85,86,87,] # 75-> crocodile
from tqdm import tqdm
train_cifar10 = []
test_cifar10 = []
for i in tqdm(train_data):
    if i[1] in sub_classes_10:
        train_cifar10.append(i)
    
for i in tqdm(test_data):
    if i[1] in sub_classes_10:
        test_cifar10.append(i)
        
class CIFAR10_train(Dataset):
    def __init__(self):
        # data loading
        train_x = list(zip(*train_cifar10))[0]
        train_y = list(zip(*train_cifar10))[1]
        
        train_y = np.array(train_y)
        
        self.x = torch.stack(train_x)
        self.y = torch.from_numpy(train_y)  
        
        
        self.n_samples = train_y.shape[0]
        
    def __getitem__(self, index):
        #dataset[0]
        return self.x[index], self.y[index]

    def __len__(self):
        # len(dataset)
        return self.n_samples
        
class CIFAR10_test(Dataset):
    def __init__(self):
        # data loading
        test_x = list(zip(*test_cifar10))[0]
        test_y = list(zip(*test_cifar10))[1]
        
        test_y = np.array(test_y)
        
        self.x = torch.stack(test_x)
        self.y = torch.from_numpy(test_y)    
        
        self.n_samples = test_y.shape[0]
        
    def __getitem__(self, index):
        #dataset[0]
        return self.x[index], self.y[index]

    def __len__(self):
        # len(dataset)
        return self.n_samples
    


  1%|          | 534/50000 [00:03<05:26, 151.70it/s]


KeyboardInterrupt: 