In [4]:
import torch
import os
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.models as models
import re
import torch
from glob import glob
from PIL import Image
import torchvision.transforms as transforms

In [5]:
class MyDataset(torch.utils.data.Dataset):
    """In-memory image dataset built from the ``*.jpg`` files directly inside a folder.

    Every image is decoded eagerly in ``__init__`` and kept as an RGB PIL image;
    the optional ``transform`` is applied on each ``__getitem__`` call.
    The label of an image is the first integer found in its *file name*
    (e.g. ``3_0017.jpg`` -> 3); a file name without digits gets label 0.
    """

    # First run of digits in a file name.
    _CLASS_IDX_PATTERN = re.compile(r'\d+')

    def __init__(self, folderName, transform=None):
        """Load all ``*.jpg`` files in ``folderName`` (non-recursive, sorted by path).

        Args:
            folderName: directory to scan for images.
            transform: optional callable applied to the PIL image on access.
        """
        self.transform = transform
        self.data = []
        self.label = []

        for img_path in sorted(glob(folderName + '/*.jpg')):
            # The context manager releases the file descriptor right away, so
            # loading a large folder cannot hit the OPEN_MAX limit.
            with Image.open(img_path) as img_file:
                # convert() decodes the pixels and returns an image that no
                # longer depends on the file; forcing RGB guarantees three
                # channels even for grayscale or CMYK JPEGs.
                image = img_file.convert('RGB')

            self.data.append(image)
            self.label.append(self._parse_class_idx(img_path))

    @classmethod
    def _parse_class_idx(cls, img_path):
        """Return the first integer in the file name of ``img_path`` (0 if there is none).

        Only the base name is inspected, so digits in parent directories
        (such as ``HW_3``) can never be mistaken for the class index.
        """
        match = cls._CLASS_IDX_PATTERN.search(os.path.basename(img_path))
        return int(match.group()) if match else 0

    def __len__(self):
        """Number of loaded images."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for ``idx``; ``idx`` may be an int or a tensor index."""
        if torch.is_tensor(idx):
            idx = idx.tolist()
        image = self.data[idx]
        if self.transform:
            image = self.transform(image)
        return image, self.label[idx]


# Training pipeline: random 256x256 crop (symmetric padding when the image is
# smaller), random horizontal flip, random rotation of up to 15 degrees, then
# conversion to a tensor.
trainTransform = transforms.Compose([
    transforms.RandomCrop(size=256, pad_if_needed=True, padding_mode='symmetric'),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=15),
    transforms.ToTensor(),
])

# Evaluation pipeline: deterministic 256x256 center crop, then conversion to a
# tensor (no augmentation).
testTransform = transforms.Compose([
    transforms.CenterCrop(size=256),
    transforms.ToTensor(),
])

def get_dataloader(mode='training', batch_size=32, data_root='../HW_3/data'):
    """Build a DataLoader over the images stored in ``{data_root}/{mode}``.

    Args:
        mode: one of 'training', 'testing' or 'validation'; selects both the
            sub-folder and the transform (augmentation only for 'training').
        batch_size: number of samples per batch.
        data_root: directory containing the per-mode sub-folders. Defaults to
            the location used by this notebook.

    Returns:
        A ``torch.utils.data.DataLoader`` that shuffles only in 'training' mode.

    Raises:
        AssertionError: if ``mode`` is not one of the supported values.
    """
    assert mode in ['training', 'testing', 'validation'], f'unknown mode: {mode!r}'

    is_training = (mode == 'training')

    dataset = MyDataset(
        f'{data_root}/{mode}',
        transform=trainTransform if is_training else testTransform)

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=is_training)

    return dataloader


In [6]:
# Build the loaders used by the training loop; both use mini-batches of 16.
train_dataloader = get_dataloader(mode='training', batch_size=16)
valid_dataloader = get_dataloader(mode='validation', batch_size=16)

In [7]:
'''
簡單上來說就是讓已經做得很好的大model們去告訴小model"如何"學習。 而我們如何做到這件事情呢? 就是利用大model預測的logits給小model當作標準就可以了。

為甚麼這會work?
1 例如當data不是很乾淨的時候，對一般的model來說他是個noise，只會干擾學習。透過去學習其他大model預測的logits會比較好。
2 label和label之間可能有關連，這可以引導小model去學習。例如數字8可能就和6,9,0有關係。
3 弱化已經學習不錯的target(?)，避免讓其gradient干擾其他還沒學好的task。

'''
# TODO: dataset loading, training, prediction (inference), and saving

'\n簡單上來說就是讓已經做得很好的大model們去告訴小model"如何"學習。 而我們如何做到這件事情呢? 就是利用大model預測的logits給小model當作標準就可以了。\n\n為甚麼這會work?\n1 例如當data不是很乾淨的時候，對一般的model來說他是個noise，只會干擾學習。透過去學習其他大model預測的logits會比較好。\n2 label和label之間可能有關連，這可以引導小model去學習。例如數字8可能就和6,9,0有關係。\n3 弱化已經學習不錯的target(?)，避免讓其gradient干擾其他還沒學好的task。\n\n'

In [8]:
def loss_fn_kd(output,labels,teacher_output,T=20,alpha=0.5):
    """Knowledge-distillation loss mixing a hard-label and a soft-label term.

    Args:
        output: student logits; classes are along dim 1.
        labels: ground-truth class indices for the cross-entropy term.
        teacher_output: teacher logits, same layout as ``output``.
        T: temperature dividing both sets of logits before the softmax.
        alpha: weight of the soft term; the hard term is weighted ``1 - alpha``.

    Returns:
        ``(1 - alpha) * CE(output, labels)
        + alpha * T^2 * KL(softmax(teacher/T) || softmax(output/T))``,
        with the KL term averaged over the batch.
    """
    # Hard-label term: ordinary cross-entropy against the ground truth.
    ce_term = F.cross_entropy(output, labels) * (1. - alpha)

    # Soft-label term: KLDivLoss expects log-probabilities as its input and
    # probabilities as its target.
    student_log_probs = F.log_softmax(output / T, dim=1)
    teacher_probs = F.softmax(teacher_output / T, dim=1)
    kl_criterion = nn.KLDivLoss(reduction='batchmean')
    kd_term = kl_criterion(student_log_probs, teacher_probs) * (alpha * T * T)

    return ce_term + kd_term

In [9]:
# The data loader is the same as in HW 3.
class StudentNet(nn.Module):
    '''
      Small CNN built from depthwise + pointwise convolution blocks.

      Replacing ordinary convolution layers with depthwise (Dw) and pointwise
      (Pw) ones usually costs little accuracy.

      The model is called StudentNet because it is the network that is later
      trained with knowledge distillation.

      Args:
          base: number of channels produced by the first (ordinary) conv layer;
              later layers use multiples of it.
          width_mult: scaling factor applied to the channel counts at indices
              3..6 of the internal ``bandwidth`` list (results are truncated
              to int).
    '''
    def __init__(self,base = 16,width_mult = 1):
        super(StudentNet,self).__init__()
        multiplier = [1,2,4,8,16,16,16,16]
        bandwidth =[base*m for m in multiplier]

        for i in range(3,7):
            bandwidth[i] = int(bandwidth[i]* width_mult)

        self.cnn = nn.Sequential(
            # Stem: ordinary 3x3 convolution on the RGB input.
            nn.Sequential(
                nn.Conv2d(3,bandwidth[0],3,1,1),
                nn.BatchNorm2d(bandwidth[0]),
                nn.ReLU6(),
                nn.MaxPool2d(2,2,0),
            ),
            # NOTE: this block previously used a dense 3x3 convolution (no
            # ``groups``), unlike every other block. It is now depthwise, so
            # the shape of ``cnn.1.0.weight`` differs from checkpoints saved
            # with the old layer.
            self._dw_pw_block(bandwidth[0], bandwidth[1], pool=True),
            self._dw_pw_block(bandwidth[1], bandwidth[2], pool=True),
            self._dw_pw_block(bandwidth[2], bandwidth[3], pool=True),
            # From here on the blocks no longer down-sample.
            self._dw_pw_block(bandwidth[3], bandwidth[4], pool=False),
            self._dw_pw_block(bandwidth[4], bandwidth[5], pool=False),
            self._dw_pw_block(bandwidth[5], bandwidth[6], pool=False),
            self._dw_pw_block(bandwidth[6], bandwidth[7], pool=False),
            # Global average pooling: inputs of different spatial sizes are
            # reduced to the same 1x1 shape, so the fully connected layer
            # below always receives a matching number of features.
            nn.AdaptiveAvgPool2d((1, 1)),
        )
        self.fc = nn.Sequential(
            nn.Linear(bandwidth[7],11),
        )

    @staticmethod
    def _dw_pw_block(in_channels, out_channels, pool):
        """Depthwise 3x3 conv -> BatchNorm -> ReLU6 -> pointwise 1x1 conv (-> 2x2 max-pool)."""
        layers = [
            # groups == channels makes the 3x3 convolution depthwise.
            nn.Conv2d(in_channels, in_channels, 3, 1, 1, groups=in_channels),
            nn.BatchNorm2d(in_channels),
            nn.ReLU6(),
            # The 1x1 (pointwise) convolution mixes channels and changes width.
            nn.Conv2d(in_channels, out_channels, 1),
        ]
        if pool:
            layers.append(nn.MaxPool2d(2, 2, 0))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the 11 class logits for an image batch ``x`` with 3 channels."""
        out = self.cnn(x)
        # Flatten everything but the batch dimension before the classifier.
        out = out.view(out.size()[0], -1)
        return self.fc(out)


In [10]:
# Choose the device first so that every later step (model placement,
# checkpoint loading, batches in the training loop) uses the same one and the
# notebook also runs on CPU-only machines.
device = "cuda" if torch.cuda.is_available() else "cpu"

teacher_net = models.resnet18(pretrained=False, num_classes=11).to(device)
student_net = StudentNet(base=16).to(device)

# map_location lets a checkpoint that was saved on a GPU be restored on the
# selected device. torch.load unpickles the file, so only load checkpoints
# from a trusted source.
teacher_net.load_state_dict(torch.load('./teacher_resnet18.bin', map_location=device))
optimizer = optim.AdamW(student_net.parameters(), lr=1e-3)

In [11]:
#training 
def run_epoch(dataloader,update = True,alpha = 0.5,T = 20):
    """Run one pass over ``dataloader`` with the knowledge-distillation loss.

    Uses the notebook-level ``student_net``, ``teacher_net``, ``optimizer``
    and ``device``.

    Args:
        dataloader: iterable yielding ``(inputs, hard_labels)`` batches.
        update: if True, back-propagate and step the optimizer; if False, only
            evaluate (no gradients are tracked).
        alpha: weight of the soft (teacher) term passed to ``loss_fn_kd``.
        T: distillation temperature passed to ``loss_fn_kd``.

    Returns:
        ``(mean_loss, accuracy)`` averaged over all samples seen.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    total_num,total_hit,total_loss = 0,0,0
    for inputs,hard_labels in dataloader:
        inputs = inputs.to(device)
        # as_tensor accepts both tensors and plain sequences of ints.
        hard_labels = torch.as_tensor(hard_labels, dtype=torch.long, device=device)

        # The teacher only provides targets, so it never needs gradients.
        with torch.no_grad():
            soft_labels = teacher_net(inputs)

        if update:
            optimizer.zero_grad()
        # Track gradients for the student only when its weights are updated.
        with torch.set_grad_enabled(update):
            logits = student_net(inputs)
            loss = loss_fn_kd(logits,hard_labels,soft_labels,T,alpha)
        if update:
            loss.backward()
            optimizer.step()

        batch_size = len(inputs)
        total_hit += torch.sum(torch.argmax(logits,dim=1)==hard_labels).item()
        total_num += batch_size
        # loss is a batch mean; weight it by the batch size so that a smaller
        # final batch does not skew the epoch average.
        total_loss += loss.item() * batch_size

    if total_num == 0:
        raise ValueError('run_epoch received an empty dataloader')
    return total_loss /total_num, total_hit/total_num

# The teacher is only used to produce soft targets, so keep it in eval mode.
teacher_net.eval()
now_best_acc = 0
for epoch in range(200):
    # Training pass: the student is put in train mode and its weights are updated.
    student_net.train()
    train_loss, train_acc = run_epoch(train_dataloader, update=True)

    # Validation pass: eval mode, no weight updates.
    student_net.eval()
    valid_loss, valid_acc = run_epoch(valid_dataloader, update=False)

    # Keep only the checkpoint with the best validation accuracy so far.
    if valid_acc > now_best_acc:
        now_best_acc = valid_acc
        torch.save(student_net.state_dict(), 'student_model.bin')

    epoch_stats = (epoch, train_loss, train_acc, valid_loss, valid_acc)
    print('epoch {:>3d}: train loss : {:6.4f}, acc {:6.4f} valid loss: {:6.4f}, acc {:6.4f}'.format(
        *epoch_stats))


epoch   0: train loss : 15.2940, acc 0.3020 valid loss: 15.7432, acc 0.3609
epoch   1: train loss : 13.8030, acc 0.3813 valid loss: 14.4804, acc 0.4475
epoch   2: train loss : 12.8846, acc 0.4313 valid loss: 12.7580, acc 0.5055
epoch   3: train loss : 12.1045, acc 0.4700 valid loss: 11.4829, acc 0.5499
epoch   4: train loss : 11.4562, acc 0.5004 valid loss: 11.0343, acc 0.5510
epoch   5: train loss : 11.1576, acc 0.5214 valid loss: 10.6701, acc 0.5732
epoch   6: train loss : 10.4650, acc 0.5451 valid loss: 12.3416, acc 0.5612
epoch   7: train loss : 10.1771, acc 0.5545 valid loss: 10.8382, acc 0.5650
epoch   8: train loss : 9.7046, acc 0.5697 valid loss: 9.1572, acc 0.6353
epoch   9: train loss : 9.3010, acc 0.5895 valid loss: 10.1154, acc 0.5860
epoch  10: train loss : 9.2472, acc 0.5880 valid loss: 10.2205, acc 0.6023
epoch  11: train loss : 8.8082, acc 0.6050 valid loss: 8.5861, acc 0.6397
epoch  12: train loss : 8.5867, acc 0.6096 valid loss: 9.1077, acc 0.6589
epoch  13: train los

KeyboardInterrupt: 

In [None]:
# Inference (prediction)