# ResNet with pytorch

- load data
- build model
- train and test
- transfer learning

In [3]:
import torch
import torch.nn as nn
class Flatten(nn.Module):
    def __init__(self):
        super(Flatten, self).__init__()
        
    def forward(self, x):
        shape = torch.prod(torch.tensor(x.shape[1:])).item()
        return x.view(-1, shape)
        
def plot_image(img, label, name):
    fig = plt.figure()
    for i in range(6):
        plt.subplot(2,3,i+1)
        plt.tight_layout()
        plt.imshow(img[1][0]*0.3081 + 0.1307, cmap='gray', interpolation='none')
        plt.title("{}:{}".format(name, label[i].item()))
        plt.xticks([])
        plt.yticks([])
    plt.show

In [4]:
import  os, glob
import  random, csv
from    torch.utils.data import Dataset, DataLoader
from    torchvision import transforms
from    PIL import Image

class Pokemon(Dataset):
    def __init__(self, root, resize, mode):
        super(Pokemon, self).__init__()
        
        self.root = root
        self.resize = resize
        
        self.name2label = {}
        for name in sorted(os.listdir(os.path.join(root))):
            if not os.path.isdir(os.path.join(root, name)):
                continue
            
            self.name2label[name] = len(self.name2label.keys())
        
        self.images, self.labels = self.load_csv('images.csv')
        
        if mode == 'train':
            self.images = self.images[:int(0.6*len(self.images))]
            self.labels = self.labels[:int(0.6*len(self.labels))]
        elif mode == 'val':
            self.images = self.images[int(0.6*len(self.images)):int(0.8*len(self.images))]
            self.labels = self.labels[int(0.6*len(self.labels)):int(0.8*len(self.labels))]
        elif mode == 'test':
            self.images = self.images[int(0.8*len(self.images)):]
            self.labels = self.labels[int(0.8*len(self.labels)):]
    
    def load_csv(self, filename):
        if not os.path.exists(os.path.join(self.root, filename)):
            images = []
            for name in self.name2label.keys():
                images += glob.glob(os.path.join(self.root, name, '*.png'))
                images += glob.glob(os.path.join(self.root, name, '*.jpg'))
                images += glob.glob(os.path.join(self.root, name, '*.jpeg'))
            random.shuffle(images)
            with open(os.path.join(self.root, filename), mode='w', newline='') as f:
                writer = csv.writer(f)
                for img in images:
                    name = img.split(os.sep)[-2]
                    label = self.name2label[name]
                    writer.writerow([img,label])
                print('writen into csv file:', filename)
                
        images, labels = [], []
        with open(os.path.join(self.root, filename)) as f:
            reader = csv.reader(f)
            for row in reader:
                img,label = row
                label = int(label)
                images.append(img)
                labels.append(label)
        assert len(images) == len(labels)
        return images, labels
    def __len__(self):
        return len(self.images)
    
    def denormalize(self, x_hat):
        
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        # x_hat = (x-mean)/std
        # x = x_hat*std = mean
        # x: [c, h, w]
        # mean: [3] => [3, 1, 1]
        mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)
        std = torch.tensor(std).unsqueeze(1).unsqueeze(1)
        # print(mean.shape, std.shape)
        x = x_hat * std + mean
        return x
    
    def __getitem__(self, idx):
        img, label = self.images[idx], self.labels[idx]
        tf = transforms.Compose([
            lambda x:Image.open(x).convert('RGB'), # string path= > image data
            transforms.Resize((int(self.resize*1.25), int(self.resize*1.25))),
            transforms.RandomRotation(15),
            transforms.CenterCrop(self.resize),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        
        img = tf(img)
        label = torch.tensor(label)
        return img,label
                

In [15]:
# build model
from    torch.nn import functional as F
class ResBlk(nn.Module):
    def __init__(self, inChannel, outChannel, stride=1):
        super(ResBlk, self).__init__()
        self.conv1 = nn.Conv2d(inChannel, outChannel, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(outChannel)
        self.conv2 = nn.Conv2d(outChannel, outChannel, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(outChannel)
        
        self.extra = nn.Sequential()
        if outChannel!=inChannel:
            self.extra = nn.Sequential(
            nn.Conv2d(inChannel, outChannel, kernel_size=1, stride=stride),
            nn.BatchNorm2d(outChannel))
            
    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = self.extra(x) + out
        out = F.relu(out)
        return out

In [21]:
class ResNet18(nn.Module):
    def __init__(self, num_class):
        super(ResNet18, self).__init__()
        self.conv1 = nn.Sequential(
        nn.Conv2d(3,16,kernel_size=3,stride=3,padding=0),
        nn.BatchNorm2d(16))
        
        self.blk1 = ResBlk(16,32,stride=3)
        self.blk2 = ResBlk(32,64,stride=3)
        self.blk3 = ResBlk(64,128,stride=2)
        self.blk4 = ResBlk(128,256,stride=2)
        
        self.outlayer = nn.Linear(256*3*3, num_class)
    
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.blk1(x)
        x = self.blk2(x)
        x = self.blk3(x)
        x = self.blk4(x)
        x = x.view(x.size(0), -1)
        x = self.outlayer(x)
        return x

In [35]:
import torch
from torch import optim, nn
import visdom
batchsz = 32
lr = 1e-3
epochs = 10

trainData = Pokemon('./data/pokemon', 224, mode='train')
valData = Pokemon('./data/pokemon', 224, mode='val')
testData = Pokemon('./data/pokemon', 224, mode='test')

trainLoader = DataLoader(trainData, batch_size=batchsz, shuffle=True,
                        num_workers=4)
valLoader = DataLoader(valData, batch_size=batchsz, shuffle=False,
                      num_workers=2)
testLoader = DataLoader(testData, batch_size=batchsz, shuffle=False,
                       num_workers=2)

model = ResNet18(5)
optimizer = optim.Adam(model.parameters(), lr=lr)
criteon = nn.CrossEntropyLoss()
viz = visdom.Visdom()
viz.line([0], [-1], win='loss', opts=dict(title='loss'))
viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc'))
def evaluate(model, loader):
    correct = 0
    total = len(loader.dataset)
    for x,y in loader:
        with torch.no_grad():
            logits = model(x)
            pred = logits.argmax(dim=1)
        correct += torch.eq(pred,y).sum().float().item()
    return correct/total

best_acc, best_epoch = 0,0
global_step = 0
for epoch in range(epochs):

        for step, (x,y) in enumerate(trainLoader):
            model.train()
            logits = model(x)
            loss = criteon(logits, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            viz.line([loss.item()], [global_step], win='loss', update='append')
            print(loss.item(), global_step)
            global_step += 1

        if epoch % 1 == 0:

            val_acc = evalute(model, valLoader)
            if val_acc> best_acc:
                best_epoch = epoch
                best_acc = val_acc

                torch.save(model.state_dict(), 'best.mdl')

                viz.line([val_acc], [global_step], win='val_acc', update='append')
print('best_acc:', best_acc,'best_epoch:', best_epoch)

model.load_state_dict(torch.load('best.mdl'))
test_acc = evalute(model, testLoader)
print('test_acc:',test_acc)
        




1.5182881355285645 0
2.103452205657959 1
2.068197727203369 2
1.9896528720855713 3
2.0654709339141846 4
0.7708738446235657 5
2.038616180419922 6
0.6397014856338501 7
0.9876345992088318 8
0.882544994354248 9
1.1049165725708008 10
0.6454030275344849 11
0.9028308391571045 12
0.7742918133735657 13
0.46328219771385193 14
0.5010548233985901 15
0.9106078147888184 16


  ' expressed in bytes should be converted ' +


0.26927223801612854 17
0.2981182336807251 18
0.9401915669441223 19
0.8364970088005066 20
0.8666273951530457 21


  ' expressed in bytes should be converted ' +
  ' expressed in bytes should be converted ' +


0.20236420631408691 22
0.5642319321632385 23
0.6349573135375977 24
0.40248703956604004 25
0.4329265058040619 26
0.6033247113227844 27
0.8497241735458374 28
0.3229106068611145 29
0.26927077770233154 30
0.19804391264915466 31
0.15070338547229767 32
0.6603262424468994 33
0.3132723569869995 34
0.3251098394393921 35
0.3701622784137726 36
0.5432785153388977 37
1.3515344858169556 38
0.9648956060409546 39
0.6611078381538391 40
0.5435547232627869 41
0.662379801273346 42
0.2745145559310913 43


  ' expressed in bytes should be converted ' +


0.44152846932411194 44
0.5021432042121887 45
0.28868770599365234 46


  ' expressed in bytes should be converted ' +


0.41890308260917664 47
0.17976997792720795 48
0.2890416383743286 49
0.19006359577178955 50
0.3342059850692749 51
0.34223130345344543 52
0.301426500082016 53
0.47043895721435547 54
0.4641607701778412 55
0.6855977773666382 56
0.4535382091999054 57
0.21803301572799683 58
0.19373226165771484 59
0.4124784469604492 60
0.39980462193489075 61
0.14749546349048615 62
0.2610394358634949 63
0.19463536143302917 64
0.37893179059028625 65


  ' expressed in bytes should be converted ' +


0.1933976709842682 66
0.2926105856895447 67
0.3652646541595459 68
0.20480839908123016 69
0.2366904616355896 70
0.11189503222703934 71
0.08500628173351288 72


  ' expressed in bytes should be converted ' +


0.8213914036750793 73
0.3684108555316925 74
0.1268329918384552 75
0.2117617428302765 76
0.2701195478439331 77
0.07751460373401642 78
0.15555301308631897 79
0.2562590539455414 80
0.16732099652290344 81
0.3893750011920929 82
0.10876034945249557 83
0.19097016751766205 84
0.15446113049983978 85
0.4122318923473358 86
0.3321864604949951 87


  ' expressed in bytes should be converted ' +


0.30623143911361694 88
0.19131679832935333 89
0.09633930027484894 90
0.2454959899187088 91
0.13489826023578644 92
0.20842348039150238 93
0.11459426581859589 94
0.14208580553531647 95
0.11801980435848236 96
0.0648353323340416 97
0.10164263844490051 98
0.1570172905921936 99
0.31430909037590027 100
0.30121222138404846 101


  ' expressed in bytes should be converted ' +


0.1856570690870285 102
0.22559285163879395 103
0.17752188444137573 104
0.1566973626613617 105
0.3598528504371643 106
0.21164388954639435 107
0.06415168195962906 108
0.115009605884552 109


  ' expressed in bytes should be converted ' +


0.3740781843662262 110
0.04966244474053383 111
0.16519340872764587 112
0.25447380542755127 113
0.02255323715507984 114
0.22495262324810028 115
0.02041688933968544 116
0.2654249966144562 117
0.14234080910682678 118


  ' expressed in bytes should be converted ' +


0.09890533983707428 119
0.1089281290769577 120
0.14453570544719696 121
0.06562875956296921 122
0.1481994092464447 123
0.14144107699394226 124
0.1803153157234192 125
0.10038936883211136 126
0.24415309727191925 127
0.022119730710983276 128
0.3638387620449066 129
0.4302480220794678 130
0.26322466135025024 131


  ' expressed in bytes should be converted ' +


0.051636502146720886 132
0.1825401484966278 133
0.28560659289360046 134
0.025978052988648415 135
0.5876161456108093 136
0.11468280851840973 137
0.13202053308486938 138
0.2746155261993408 139


  ' expressed in bytes should be converted ' +


0.20237937569618225 140
0.05579480156302452 141
0.1419968605041504 142
0.044787876307964325 143
0.13735675811767578 144
0.07284069806337357 145
0.06625097990036011 146
0.48213109374046326 147
0.21015653014183044 148
0.5998139977455139 149
0.07734183967113495 150
0.13093870878219604 151
0.0682600662112236 152
0.21881122887134552 153


  ' expressed in bytes should be converted ' +


0.3889780044555664 154
0.23757795989513397 155
0.13890765607357025 156
0.2523658871650696 157
0.11367856711149216 158
0.11287233233451843 159
0.5022757053375244 160
0.022356029599905014 161
0.045778051018714905 162


  ' expressed in bytes should be converted ' +


0.35135412216186523 163
0.13932223618030548 164
0.03947664052248001 165
0.429040789604187 166
0.5030101537704468 167
0.1739075779914856 168
0.26843610405921936 169
0.1280399113893509 170
0.19253885746002197 171
0.05064299330115318 172
0.15613166987895966 173
0.1485508233308792 174
0.12354399263858795 175


  ' expressed in bytes should be converted ' +
  ' expressed in bytes should be converted ' +


0.23393459618091583 176
0.21928325295448303 177
0.16773898899555206 178
0.14287392795085907 179
0.21804362535476685 180
0.10113179683685303 181
0.005801535211503506 182
0.045602116733789444 183
0.14524514973163605 184
0.017220256850123405 185
0.031919315457344055 186
0.11641865223646164 187
0.0707625225186348 188
0.39153528213500977 189
0.8241366147994995 190
0.07546449452638626 191
0.049078069627285004 192
0.15445783734321594 193
0.21944354474544525 194
0.02491134963929653 195
0.027529647573828697 196
0.03903251141309738 197


  ' expressed in bytes should be converted ' +


0.11700162291526794 198
0.1360456794500351 199
0.07049642503261566 200
0.06530829519033432 201
0.13503466546535492 202
0.13573579490184784 203
0.3640984296798706 204
0.20868128538131714 205
0.06068388372659683 206
0.4632214307785034 207
0.10328789055347443 208
0.4642248749732971 209


  ' expressed in bytes should be converted ' +


0.33785462379455566 210
0.11723791062831879 211
0.02525440603494644 212
0.11242695152759552 213
0.04972611367702484 214
0.22670121490955353 215
1.3936501741409302 216
0.1763879805803299 217
0.1639140397310257 218
0.06320852041244507 219


  ' expressed in bytes should be converted ' +


best_acc: 0.8755364806866953 best_epoch: 8
test_acc: 0.8497854077253219


In [42]:
resize = 224
tf = transforms.Compose([
            lambda x:Image.open(x).convert('RGB'), # string path= > image data
            transforms.Resize((int(resize*1.25), int(resize*1.25))),
            transforms.RandomRotation(15),
            transforms.CenterCrop(resize),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
img = tf('./data/pokemon/bulbasaur/00000001.png')

In [64]:
from torch.autograd import Variable
pred = model(img.view(1,3,224,224)).argmax(dim=1)
int(pred.numpy())

0

In [65]:
trainData.name2label

{'bulbasaur': 0, 'charmander': 1, 'mewtwo': 2, 'pikachu': 3, 'squirtle': 4}