In [1]:
import os
import random
import glob
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.optim import Adam
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image

In [2]:
use_cuda = torch.cuda.is_available()
torch.manual_seed(123)
device = torch.device("cuda" if use_cuda else "cpu")
print('Device used:', device)

Device used: cuda


In [3]:
def load_data(img_path, label_path):
    train_image = sorted(glob.glob(os.path.join(img_path, '*.jpg')))
    train_label = pd.read_csv(label_path)
    train_label = train_label.iloc[:,1].values.tolist()
    
    train_data = list(zip(train_image, train_label))
    random.shuffle(train_data)
    
    train_set = train_data[:28800]
    valid_set = train_data[28800:]
    
    return train_set, valid_set

In [4]:
class hw3_dataset(Dataset):
    
    def __init__(self, data, transform):
        self.data = data
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img = Image.open(self.data[idx][0]).convert('RGB')
        img = self.transform(img)
        label = self.data[idx][1]
        return img, label

In [5]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=5, padding=2),
            nn.Dropout2d(0.4),
            nn.LeakyReLU(negative_slope=0.05),
            nn.BatchNorm2d(32),
            nn.MaxPool2d(2),     
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.Dropout2d(0.3),
            nn.LeakyReLU(negative_slope=0.05),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(2),            
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3,padding=1),
            nn.Dropout2d(0.3),
            nn.LeakyReLU(negative_slope=0.05),
            nn.BatchNorm2d(128),
            nn.MaxPool2d(2),
            
        )
        self.conv4 = nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=3,padding=1),
            nn.Dropout2d(0.3),
            nn.LeakyReLU(negative_slope=0.05),
            nn.BatchNorm2d(128),
            nn.MaxPool2d(2),
            
        )
        self.fc = nn.Sequential(
            nn.Linear(3*3*128, 256),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Linear(256, 7)
        )

    def forward(self, x):
        #image size (48,48)
        x = self.conv1(x) #(24,24)
        x = self.conv2(x) #(12,12)
        x = self.conv3(x) #(6,6)
        x = self.conv4(x) #(3,3)
        x = x.view(-1, 3*3*128)
        x = self.fc(x)
        return x

In [6]:
if __name__ == '__main__':
    use_gpu = torch.cuda.is_available()
    print(use_gpu)
    train_set, valid_set = load_data('./train_img/', 'train.csv')

    #transform to tensor, data augmentation
    
    transform = transforms.Compose([
    #transforms.RandomAffine(15, translate=(0.1,0.1), scale=(0.9,1.1), shear=10, fillcolor=0),
    #transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    #transforms.Normalize([mean], [std], inplace=False)
    ])
    
    train_dataset = hw3_dataset(train_set,transform)
    print(len(train_dataset))
    train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)

    valid_dataset = hw3_dataset(valid_set,transform)
    print(len(valid_dataset))
    valid_loader = DataLoader(valid_dataset, batch_size=256, shuffle=False)

    model = Net()
    if use_gpu:
        model.cuda()

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    loss_fn = nn.CrossEntropyLoss()

    num_epoch = 600
    with open("loss_acc.csv","w") as f:
        for epoch in range(num_epoch):
            model.train()
            train_loss = []
            train_acc = []
            for idx, (img, label) in enumerate(train_loader):

                if use_gpu:
                    img = img.cuda()
                    label = label.cuda()
                optimizer.zero_grad()
                output = model(img)
                loss = loss_fn(output, label)
                loss.backward()
                optimizer.step()

                predict = torch.max(output, 1)[1]
                acc = np.mean((label == predict).cpu().numpy())
                train_acc.append(acc)
                train_loss.append(loss.item())
            print("{},{:.4f},{:.4f}".format(epoch+1, np.mean(train_loss), np.mean(train_acc)))
            #print("？{:.6f},{:.6f}".format(np.mean(train_loss), np.mean(train_acc)) ,file = f)

            model.eval()
            with torch.no_grad():
                valid_loss = []
                valid_acc = []
                for idx, (img, label) in enumerate(valid_loader):
                    if use_gpu:

                        img = img.cuda()
                        label = label.cuda()
                    output = model(img)
                    loss = loss_fn(output, label)
                    predict = torch.max(output, 1)[1]
                    acc = np.mean((label == predict).cpu().numpy())
                    valid_loss.append(loss.item())
                    valid_acc.append(acc)
                print("{},{:.4f},{:.4f}".format(epoch+1, np.mean(valid_loss), np.mean(valid_acc)))
                #print("{:.6f},{:.6f}".format( np.mean(valid_loss),np.mean(valid_acc)) ,file = f)
            if np.mean(train_acc) > 0.9 and epoch%50==0:
                checkpoint_path = './model_3channel/model_{}.pth'.format(epoch+1) 
                torch.save(model.state_dict(), checkpoint_path)
                #print('model saved to %s' % checkpoint_path)
            
    

    #finish test code

    

True
28800
88
1,1.8501,0.2704
1,1.5952,0.3636
2,1.6001,0.3763
2,1.4540,0.4545
3,1.4854,0.4257
3,1.4485,0.4205
4,1.4098,0.4557
4,1.3499,0.4432
5,1.3532,0.4803
5,1.3391,0.4773
6,1.3002,0.5010
6,1.2578,0.5568
7,1.2634,0.5161
7,1.2183,0.5114
8,1.2307,0.5316
8,1.1905,0.5000
9,1.1982,0.5409
9,1.1972,0.5114
10,1.1643,0.5556
10,1.1604,0.5455
11,1.1398,0.5706
11,1.1045,0.5568
12,1.1088,0.5793
12,1.0091,0.6591
13,1.0835,0.5898
13,1.0623,0.5795
14,1.0554,0.5997
14,1.0691,0.5682
15,1.0404,0.6053
15,1.1057,0.5909
16,1.0140,0.6185
16,1.1369,0.5795
17,0.9906,0.6272
17,1.0205,0.6023
18,0.9667,0.6361
18,1.0655,0.5909
19,0.9514,0.6401
19,1.0097,0.5795
20,0.9348,0.6459
20,1.0902,0.5682
21,0.9089,0.6558
21,1.0742,0.5909
22,0.8914,0.6648
22,1.0832,0.5795
23,0.8684,0.6738
23,1.0431,0.6591
24,0.8501,0.6817
24,1.1062,0.5795
25,0.8329,0.6894
25,1.1500,0.6250
26,0.8180,0.6920
26,1.1048,0.6136
27,0.8075,0.7000
27,1.1345,0.6023
28,0.7810,0.7101
28,1.0829,0.6364
29,0.7637,0.7143
29,1.1928,0.5795
30,0.7549,0.7200
3

235,0.1885,0.9325
235,1.9566,0.6250
236,0.1935,0.9310
236,2.0122,0.6250
237,0.1951,0.9317
237,1.9185,0.6136
238,0.1937,0.9319
238,2.0555,0.6136
239,0.1919,0.9318
239,2.1153,0.6136
240,0.1952,0.9294
240,2.1032,0.5909
241,0.1912,0.9314
241,2.0741,0.6023
242,0.1943,0.9312
242,2.0321,0.5909
243,0.2011,0.9292
243,1.8572,0.5682
244,0.1962,0.9300
244,1.9686,0.5909
245,0.1977,0.9308
245,2.0009,0.5909
246,0.1963,0.9301
246,2.0449,0.6023
247,0.1962,0.9318
247,2.1377,0.6136
248,0.1930,0.9323
248,2.2541,0.6136
249,0.1906,0.9323
249,2.0649,0.6023
250,0.1915,0.9331
250,2.0742,0.5682
251,0.1917,0.9326
251,2.1341,0.6136
252,0.1921,0.9325
252,2.0385,0.6023
253,0.1957,0.9319
253,2.1639,0.6023
254,0.1956,0.9303
254,2.0889,0.6250
255,0.1908,0.9331
255,2.1546,0.6023
256,0.1848,0.9351
256,2.0450,0.5909
257,0.1879,0.9317
257,2.0030,0.5909
258,0.1819,0.9345
258,1.9910,0.6250
259,0.1858,0.9340
259,2.0507,0.5795
260,0.1874,0.9341
260,2.1192,0.6023
261,0.1866,0.9334
261,2.0691,0.5909
262,0.1842,0.9347
262,2.1092

463,0.1402,0.9513
463,2.2015,0.5909
464,0.1372,0.9522
464,2.3183,0.6250
465,0.1366,0.9519
465,2.1632,0.6136
466,0.1324,0.9528
466,2.3179,0.6136
467,0.1381,0.9529
467,2.2753,0.5909
468,0.1386,0.9529
468,2.2286,0.5795
469,0.1358,0.9517
469,2.2706,0.5795
470,0.1428,0.9501
470,2.2793,0.5909
471,0.1448,0.9498
471,2.3514,0.5909
472,0.1385,0.9506
472,2.3831,0.5568
473,0.1300,0.9543
473,2.3313,0.6023
474,0.1440,0.9500
474,2.3702,0.6023
475,0.1367,0.9517
475,2.2500,0.5909
476,0.1425,0.9504
476,2.2152,0.5909
477,0.1401,0.9529
477,2.2746,0.5795
478,0.1350,0.9532
478,2.3121,0.6136
479,0.1385,0.9528
479,2.2382,0.6023
480,0.1347,0.9540
480,2.1279,0.6023
481,0.1416,0.9513
481,2.2485,0.6250
482,0.1363,0.9528
482,2.2259,0.6136
483,0.1492,0.9489
483,2.1839,0.6023
484,0.1355,0.9528
484,2.1544,0.5909
485,0.1417,0.9509
485,2.3879,0.6250
486,0.1373,0.9513
486,2.3427,0.6136
487,0.1438,0.9508
487,2.3180,0.6136
488,0.1370,0.9531
488,2.3420,0.6250
489,0.1388,0.9502
489,2.3650,0.6250
490,0.1376,0.9539
490,2.3996

In [7]:
checkpoint_path = './model_3channel/model_601.pth'
torch.save(model.state_dict(), checkpoint_path)
print('model saved to %s' % checkpoint_path)

model saved to ./model_3channel/model_601.pth
