In [1]:
# Evaluation settings shared by the cells below.
opt = {
    # Load the pretrained weights from the saved checkpoint.
    'Load_weights': True,

    # Directories that contain the test data.
    'dataroot_1024': 'data/Raise_dataset_init_crop_1024',
    'dataroot_none': 'data/Raise_dataset_init_crop_none',
    'dataroot_512': 'data/Raise_dataset_init_crop_512',

    # ImageNet channel statistics, used because the network starts from
    # ImageNet pre-trained weights.
    'normalizeMean': [0.485, 0.456, 0.406],
    'normalizeStd': [0.229, 0.224, 0.225],

    'batch_size': 10,
    # The original note calls this "the actual batch size".
    # NOTE(review): the DataLoaders built in this notebook read 'batch_size';
    # no visible cell reads 'sizeit' -- confirm whether it is still needed.
    'sizeit': 64,

    # Change this to "cpu" if CUDA or a GPU is not available.
    'device': "cuda",
}

# Because the code has to process about 25,000 images (over 20 GB of data) on the
# GPU, a complete run takes about 15 minutes. Most of that time is spent on the
# very high resolution images; remove them to speed things up.

In [2]:
import random
import numpy as np
from PIL import Image
import torchvision.transforms as transforms

import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils, models
import torchvision.transforms.functional as F

import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as FF
import matplotlib.pyplot as plt
import itertools

In [3]:
def make_trns(size,opt):
    """Build the preprocessing pipeline for one patch size.

    Args:
        size: crop size handed to ``transforms.CenterCrop``.
        opt: settings dict; ``opt['normalizeMean']`` and ``opt['normalizeStd']``
            supply the per-channel normalisation statistics.

    Returns:
        A ``transforms.Compose`` that centre-crops, converts to a tensor and
        normalises, in that order.
    """
    normalize = transforms.Normalize(opt['normalizeMean'], opt['normalizeStd'])
    return transforms.Compose([
        transforms.CenterCrop(size),
        transforms.ToTensor(),
        normalize,
    ])

# One preprocessing pipeline per patch size used in the evaluation (128, 256, 512).
trns128, trns256, trns512 = (make_trns(patch_size, opt) for patch_size in (128, 256, 512))


### Data Loading Part

In [4]:
class get_data(Dataset):
    """Loads the test images of the five resampling folders side by side.

    Item ``idx`` bundles the ``idx``-th file (in sorted filename order) of each of
    the folders ``90_06/test``, ``90_08/test``, ``90_1/test``, ``90_12/test`` and
    ``90_14/test`` under the selected data root, together with the class labels
    0..4 (the evaluation cell reports them as resampling factors 0.6 .. 1.4).

    Args:
        opt: settings dict; one of ``dataroot_512``, ``dataroot_1024`` or
            ``dataroot_none`` is read, depending on ``whichsize``.
        transform: callable applied to every RGB PIL image. With ``None`` the
            PIL images are returned unchanged.
        whichsize: ``'64'`` selects ``opt['dataroot_512']``, ``'128'`` selects
            ``opt['dataroot_1024']``, any other value selects
            ``opt['dataroot_none']``.

    Raises:
        FileNotFoundError: if one of the five folders does not exist.
    """

    # Sub-folders of the data root, in the order that defines image1..image5
    # and label1..label5.
    _SUBDIRS = ('90_06/test', '90_08/test', '90_1/test', '90_12/test', '90_14/test')

    def __init__(self, opt, transform=None, whichsize=None):

        if whichsize=='64':
            self.root_dir = opt['dataroot_512']
        elif whichsize=='128':
            self.root_dir = opt['dataroot_1024']
        else:
            self.root_dir = opt['dataroot_none']

        self.transform = transform

        self._files = [self._list_files(os.path.join(self.root_dir, subdir))
                       for subdir in self._SUBDIRS]
        # Kept as individual attributes for code that reads them by name.
        (self.train1files, self.train2files, self.train3files,
         self.train4files, self.train5files) = self._files

    @staticmethod
    def _list_files(directory):
        """Return the sorted names of the non-directory entries directly inside ``directory``."""
        if not os.path.isdir(directory):
            raise FileNotFoundError('Test image folder not found: {}'.format(directory))
        return sorted(name for name in os.listdir(directory)
                      if not os.path.isdir(os.path.join(directory, name)))

    def __len__(self):
        # Items are paired by position, so the shortest folder bounds the length.
        return min(len(names) for names in self._files)

    def __getitem__(self, idx):
        sample = {}
        for position, (subdir, names) in enumerate(zip(self._SUBDIRS, self._files), start=1):
            path = os.path.join(self.root_dir, subdir, names[idx])
            # convert() returns an independent image, so the file can be closed right away.
            with Image.open(path) as raw:
                image = raw.convert('RGB')
            if self.transform is not None:
                image = self.transform(image)
            sample['image{}'.format(position)] = image

        # Labels follow the images so the key order stays image1..5, label1..5.
        for position in range(1, len(self._SUBDIRS) + 1):
            sample['label{}'.format(position)] = position - 1

        return sample

In [5]:


# One entry per evaluation setting: (key stored in opt, patch transform, whichsize).
# get_data maps whichsize '64' -> opt['dataroot_512'], '128' -> opt['dataroot_1024'],
# and any other value (here '512') -> opt['dataroot_none'].
test_loader_specs = [
    ('dataloader_test512v128',  trns128, '64'),   # base resolution: 512,   patch size: 128
    ('dataloader_test1024v128', trns128, '128'),  # base resolution: 1024,  patch size: 128
    ('dataloader_test1024v256', trns256, '128'),  # base resolution: 1024,  patch size: 256
    ('dataloader_test3000v256', trns256, '512'),  # base resolution: >1024, patch size: 256
    ('dataloader_test3000v512', trns512, '512'),  # base resolution: >1024, patch size: 512
]

for loader_key, patch_transform, whichsize in test_loader_specs:
    test_dataset = get_data(opt, patch_transform, whichsize)
    # Evaluation only counts correct predictions, so shuffling gains nothing;
    # a fixed order keeps runs reproducible.
    opt[loader_key] = DataLoader(test_dataset, batch_size=opt['batch_size'],
                                 shuffle=False, num_workers=1, pin_memory=False)

### Model files

In [6]:
class my_resnet(nn.Module):
    """ResNet-18 based 5-class classifier that accepts several patch sizes.

    The first six children of torchvision's ``resnet18`` form ``feat1`` and
    children 6..7 form ``feat2``. Between them one stride-2 convolution
    (``conv``) plus batch norm is always applied; depending on the resulting
    feature-map height, ``conv`` is applied again zero, one or two times before
    a 2x2 max-pool. ``feat2``'s output is then flattened with ``view(-1, 512)``
    and classified by ``fc``.
    """

    def __init__(self):
        super().__init__()

        # ImageNet-pretrained backbone; only children 0..7 are kept.
        backbone_children = list(models.resnet18(pretrained=True).children())
        self.feat1 = nn.Sequential(*backbone_children[:6])
        self.feat2 = nn.Sequential(*backbone_children[6:8])

        # Registered but not used in forward().
        self.rel = nn.ReLU(inplace=True)

        self.fc = nn.Linear(512, 5)

        # Stride-2 convolution; forward() reuses the same weights for every
        # additional downsampling step.
        self.conv = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3,
                              stride=2, padding=1, bias=True)
        self.bn = nn.BatchNorm2d(128)

        self.roi1 = nn.MaxPool2d(2)

    def forward(self,out):
        """Return the 5 class scores for a batch of image patches."""
        out = self.feat1(out)
        out = self.conv(out)
        out = self.bn(out)

        # Feature-map height -> number of extra stride-2 convolutions before
        # the max-pool. Heights 8, 16 and 32 all end up at height 4; any other
        # height skips both the extra convolutions and the pooling.
        extra_convs = {8: 0, 16: 1, 32: 2}.get(out.size()[2])
        if extra_convs is not None:
            for _ in range(extra_convs):
                out = self.conv(out)
            out = self.roi1(out)

        out = self.feat2(out)
        return self.fc(out.view(-1,512))

In [7]:
class common_functions():
    """Builds the classifier, optionally restores its checkpoint and reports test accuracy.

    Args:
        opt: settings dict. Reads ``device`` and ``Load_weights`` at construction
            time and the five ``dataloader_test*`` entries in
            ``optimize_parameters``.
    """

    # Number of classes / image streams per batch (image1..image5, label1..label5).
    _NUM_CLASSES = 5

    # (key of the loader in opt, header printed above its results).
    # The headers are kept verbatim so the printed report stays unchanged.
    _REPORT_SECTIONS = (
        ('dataloader_test512v128', 'Image size 512x512 with patch size 128x128'),
        ('dataloader_test1024v128', '\n Image size 1024x1024 with patch size 128x128'),
        ('dataloader_test1024v256', '\n Image size 1024x1024 with patc size 256x256'),
        ('dataloader_test3000v256', '\n Image sizes greater than 1024x1024 with patch size 256x256'),
        ('dataloader_test3000v512', '\n Image sizes greater than 1024x1024 with patch size 512x512'),
    )
    _REPORT_LINE = 'Resampling Factor 0.6:{}/{} 0.8:{}/{} 1:{}/{} 1.2:{}/{} 1.4:{}/{}'

    def __init__(self, opt):

        self.opt = opt

        self.best_test_acc = 0
        self.lowest_train_err = 100
        self.device = torch.device(self.opt['device'])
        self.test_acc = 0
        self.count_save = 0
        self.step_count = 0

        model = my_resnet()
        self.model = model.to(self.device)
        print(self.model)
        print(next(self.model.parameters()).is_cuda)

        if opt['Load_weights']:
            # map_location lets a checkpoint saved on a GPU be restored when
            # opt['device'] is "cpu".
            checkpoint = torch.load('BestModel', map_location=self.device)
            self.model.load_state_dict(checkpoint['model'])

    def find_acc(self, loader):
        """Count correct predictions per class over ``loader``.

        Args:
            loader: iterable of batches; each batch is a dict holding
                ``image1..image5`` and ``label1..label5`` tensors.

        Returns:
            The tuple ``(total1, ..., total5, correct1, ..., correct5, test_acc)``
            where ``test_acc`` is the mean of the five per-class accuracies in
            percent. A class that saw no samples (e.g. an empty loader)
            contributes an accuracy of 0.0 instead of dividing by zero.
        """
        class_ids = range(1, self._NUM_CLASSES + 1)
        correct = [0] * self._NUM_CLASSES
        total = [0] * self._NUM_CLASSES

        self.model.eval()

        with torch.no_grad():
            for images in loader:
                for slot, class_id in enumerate(class_ids):
                    batch = images['image{}'.format(class_id)].to(self.device)
                    labels = images['label{}'.format(class_id)].to(self.device)

                    # Index of the highest score = predicted class.
                    _, predicted = torch.max(self.model(batch), 1)

                    total[slot] += len(labels)
                    correct[slot] += (predicted == labels).sum().item()

        per_class_acc = [100 * float(hits) / float(seen) if seen else 0.0
                         for hits, seen in zip(correct, total)]
        test_acc = sum(per_class_acc) / self._NUM_CLASSES

        return (*total, *correct, test_acc)

    def optimize_parameters(self):
        """Evaluate on every test loader and print correct/total per resampling factor.

        Despite the name, no parameters are optimised here: each of the five
        loaders stored in ``opt`` is passed to ``find_acc`` and the counts are
        printed.
        """
        n = self._NUM_CLASSES
        for loader_key, header in self._REPORT_SECTIONS:
            results = self.find_acc(self.opt[loader_key])
            totals, corrects = results[:n], results[n:2 * n]

            print(header)
            # Interleave as correct1, total1, correct2, total2, ...
            counts = [value for pair in zip(corrects, totals) for value in pair]
            print(self._REPORT_LINE.format(*counts))

        
        
       
            
            
        
        

#### Iterating over the dataset

In [8]:
# Build the evaluator: common_functions constructs my_resnet on opt['device'],
# prints the model, and restores the 'BestModel' checkpoint when
# opt['Load_weights'] is set.
# NOTE(review): despite the name `gan_model`, no GAN appears in the visible code.
gan_model = common_functions(opt)
# Despite its name, optimize_parameters only evaluates: it runs find_acc on the
# five test loaders stored in opt and prints correct/total per resampling factor.
gan_model.optimize_parameters()

my_resnet(
  (feat1): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

In [None]:
### So, with the base-size vs. patch-size selection criteria given in the paper (see the table titled 'Patch sizes for the proposed IPN and BN.'), the accuracy (correct/total per resampling factor) is:

Image size 512x512 with patch size 128x128
Resampling Factor 0.6:968/1001 0.8:976/1001 1:961/1001 1.2:935/1001 1.4:936/1001

 Image size 1024x1024 with patch size 128x128
Resampling Factor 0.6:979/1001

 Image size 1024x1024 with patc size 256x256
Resampling Factor 0.8:989/1001 1:995/1001 1.2:997/1001

 Image sizes greater than 1024x1024 with patch size 256x256
Resampling Factor 0.6:991/1001 0.8:959/1001 1:992/1001

 Image sizes greater than 1024x1024 with patch size 512x512
Resampling Factor 1:994/1001 1.2:963/1001 1.4:999/1001