Homework 8 - Network Compression
===

In [12]:
!pip install gdown
!pip install torchsummary

[0m

In [13]:
### This block is same as HW3 ###
# Import necessary packages.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch
import torchvision.transforms as transforms
import torchvision.models as models

from PIL import Image
from torch.utils.data import ConcatDataset, DataLoader, Subset
from torchvision.datasets import DatasetFolder

# This is for the progress bar.
from tqdm.auto import tqdm
import random

## **Dataset, Data Loader, and Transforms** *(similar to HW3)*

Torchvision provides lots of useful utilities for image preprocessing, data wrapping as well as data augmentation.

Here, since our data are stored in folders by class labels, we can directly apply **torchvision.datasets.DatasetFolder** for wrapping data without much effort.

Please refer to [PyTorch official website](https://pytorch.org/vision/stable/transforms.html) for details about different transforms.

---
**The only diffference with HW3 is that the transform functions are different.**

In [14]:

train_tfm = transforms.Compose([
  # Resize the image into a fixed shape (height = width = 142)
    transforms.Resize((142, 142)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.RandomCrop(128),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

test_tfm = transforms.Compose([
    # Resize the image into a fixed shape (height = width = 142)
    transforms.Resize((142, 142)),
    transforms.CenterCrop(128),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


In [15]:
### This block is similar to HW3 ###
# Batch size for training, validation, and testing.
# A greater batch size usually gives a more stable gradient.
# But the GPU memory is limited, so please adjust it carefully.
batch_size = 64

# Construct datasets.
# The argument "loader" tells how torchvision reads the data.
path_root = '/kaggle/input/lastmlhw08/food-11'
train_set = DatasetFolder(path_root+"/training/labeled", loader=lambda x: Image.open(x), extensions="jpg", transform=train_tfm)
valid_set = DatasetFolder(path_root+"/validation", loader=lambda x: Image.open(x), extensions="jpg", transform=test_tfm)
unlabeled_set = DatasetFolder(path_root+"/training/unlabeled", loader=lambda x: Image.open(x), extensions="jpg", transform=train_tfm)
test_set = DatasetFolder(path_root+"/testing", loader=lambda x: Image.open(x), extensions="jpg", transform=test_tfm)

# Construct data loaders.
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True)
valid_loader = DataLoader(valid_set, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)

In [16]:
class StudentNet(nn.Module):
    def __init__(self):
        super(StudentNet, self).__init__()

        self.cnn = nn.Sequential(
            nn.Conv2d(3, 64, 3, 1, 0),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2, 2, 0),

            self.dwpw_conv(64,128,3),
            nn.MaxPool2d(2, 2, 0),
            
            self.dwpw_conv(128,256,3),
            nn.MaxPool2d(2, 2, 0),
            
            self.dwpw_conv(256,170,3),

            # Here we adopt Global Average Pooling for various input size.
            nn.AdaptiveAvgPool2d((1, 1)),
        )
        self.fc = nn.Sequential(
            nn.Linear(170, 32),
            nn.Linear(32, 11)
        )
      
    def forward(self, x):
        out = self.cnn(x)
        out = out.view(out.size()[0], -1)
        return self.fc(out)
    
    def dwpw_conv(self, in_chs, out_chs, k=3, s=1, p=0):
        return nn.Sequential(
            nn.Conv2d(in_chs, in_chs, k, s, p, groups=in_chs),
            nn.BatchNorm2d(in_chs),
            nn.ReLU(),
            nn.Conv2d(in_chs, out_chs, 1),
            nn.BatchNorm2d(out_chs),
            nn.ReLU(),
        )

## **Model Analysis**

Use `torchsummary` to get your model architecture (screenshot or pasting text are allowed.) and numbers of 
parameters, these two information should be submit to your NTU Cool questions.

Note that the number of parameters **should not greater than 100,000**, or you'll get penalty in this homework.

In [17]:
from torchsummary import summary

student_net = StudentNet()
summary(student_net, (3, 128, 128), device="cpu")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 126, 126]           1,792
       BatchNorm2d-2         [-1, 64, 126, 126]             128
              ReLU-3         [-1, 64, 126, 126]               0
         MaxPool2d-4           [-1, 64, 63, 63]               0
            Conv2d-5           [-1, 64, 61, 61]             640
       BatchNorm2d-6           [-1, 64, 61, 61]             128
              ReLU-7           [-1, 64, 61, 61]               0
            Conv2d-8          [-1, 128, 61, 61]           8,320
       BatchNorm2d-9          [-1, 128, 61, 61]             256
             ReLU-10          [-1, 128, 61, 61]               0
        MaxPool2d-11          [-1, 128, 30, 30]               0
           Conv2d-12          [-1, 128, 28, 28]           1,280
      BatchNorm2d-13          [-1, 128, 28, 28]             256
             ReLU-14          [-1, 128,

In [18]:
def loss_fn_kd(outputs, labels, teacher_outputs, alpha=0.5, T=15):
    hard_loss = F.cross_entropy(outputs, labels) * (1. - alpha) 
    tl = F.softmax(teacher_outputs/T, dim=1)
    sl = F.log_softmax(outputs/T, dim=1)
    KL = nn.KLDivLoss(reduction='batchmean')
    soft_loss = alpha * T * T * KL(sl, tl)
    return hard_loss + soft_loss

## **Teacher Model Setting**
We provide a well-trained teacher model to help you knowledge distillation to student model.
Note that if you want to change the transform function, you should consider  if suitable for this well-trained teacher model.
* If you cannot successfully gdown, you can change a link. (Backup link is provided at the bottom of this colab tutorial).

In [19]:
# Download teacherNet
!gdown --id '1ni4IQUJ4bJJX4toHpvdXIu4VW3P29QrH' --output teacher_net.ckpt
# Load teacherNet
teacher_net = torch.load('./teacher_net.ckpt')
teacher_net.eval()

Downloading...
From (uriginal): https://drive.google.com/uc?id=1ni4IQUJ4bJJX4toHpvdXIu4VW3P29QrH
From (redirected): https://drive.google.com/uc?id=1ni4IQUJ4bJJX4toHpvdXIu4VW3P29QrH&confirm=t&uuid=a84e308c-1f61-422e-ae01-5597c8c85544
To: /kaggle/working/teacher_net.ckpt
100%|███████████████████████████████████████| 44.8M/44.8M [00:00<00:00, 196MB/s]


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

## **Generate Pseudo Labels in Unlabeled Data**

Since we have a well-trained model, we can use this model to predict pseudo-labels and help the student network train well. Note that you 
**CANNOT** use well-trained model to pseudo-label the test data. 


---

**AGAIN, DO NOT USE TEST DATA FOR PURPOSE OTHER THAN INFERENCING**

* Because If you use teacher network to predict pseudo-labels of the test data, you can only use student network to overfit these pseudo-labels without train/unlabeled data. In this way, your kaggle accuracy will be as high as the teacher network, but the fact is that you just overfit the test data and your true testing accuracy is very low. 
* These contradict the purpose of these assignment (network compression); therefore, you should not misuse the test data.
* If you have any concerns, you can email us.

In [20]:
# "cuda" only when GPUs are available.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Initialize a model, and put it on the device specified.
student_net = student_net.to(device)
teacher_net = teacher_net.to(device)

# Whether to do pseudo label.
do_semi = True

def get_pseudo_labels(dataset, model):
    loader = DataLoader(dataset, batch_size=batch_size*3, shuffle=False, pin_memory=True)
    pseudo_labels = []
    for batch in tqdm(loader):
        # A batch consists of image data and corresponding labels.
        img, _ = batch

        # Forward the data
        # Using torch.no_grad() accelerates the forward process.
        with torch.no_grad():
            logits = model(img.to(device))
            pseudo_labels.append(logits.argmax(dim=-1).detach().cpu())
        # Obtain the probability distributions by applying softmax on logits.
    pseudo_labels = torch.cat(pseudo_labels)
    # Update the labels by replacing with pseudo labels.
    for idx, ((img, _), pseudo_label) in enumerate(zip(dataset.samples, pseudo_labels)):
        dataset.samples[idx] = (img, pseudo_label.item())
    return dataset

if do_semi:
    # Generate new trainloader with unlabeled set.
    unlabeled_set = get_pseudo_labels(unlabeled_set, teacher_net)
    concat_dataset = ConcatDataset([train_set, unlabeled_set])
    train_loader = DataLoader(concat_dataset, batch_size=batch_size, shuffle=True, pin_memory=True, drop_last=True)

  0%|          | 0/36 [00:00<?, ?it/s]

## **Training** *(similar to HW3)*

You can finish supervised learning by simply running the provided code without any modification.

The function "get_pseudo_labels" is used for semi-supervised learning.
It is expected to get better performance if you use unlabeled data for semi-supervised learning.
However, you have to implement the function on your own and need to adjust several hyperparameters manually.

For more details about semi-supervised learning, please refer to [Prof. Lee's slides](https://speech.ee.ntu.edu.tw/~tlkagk/courses/ML_2016/Lecture/semi%20(v3).pdf).

Again, please notice that utilizing external data (or pre-trained model) for training is **prohibited**.

---
**The only diffference with HW3 is that you should use loss in  knowledge distillation.**

In [21]:
import math
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR


def get_cosine_schedule_with_warmup(
    optimizer: Optimizer,
    num_warmup_steps: int,
    num_training_steps: int,
    num_cycles: float = 0.5,
    last_epoch: int = -1,
):
    def lr_lambda(current_step):
        # Warmup
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        # decadence
        progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
        return max(
          0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
        )
    
    return LambdaLR(optimizer, lr_lambda, last_epoch)

In [22]:
criterion = nn.CrossEntropyLoss()
n_epochs = 100
optimizer = torch.optim.AdamW(student_net.parameters(), lr=0.001)
scheduler = get_cosine_schedule_with_warmup(optimizer, n_epochs/10, n_epochs)
best_acc = 0.0

for epoch in range(n_epochs):
    student_net.train()

    # These are used to record information in training.
    train_loss = []
    train_accs = []

    # Iterate the training set by batches.
    for batch in tqdm(train_loader):
        imgs, labels = batch
        logits = student_net(imgs.to(device))
        with torch.no_grad():
            soft_labels = teacher_net(imgs.to(device))
        
        loss = loss_fn_kd(logits, labels.to(device), soft_labels)
        optimizer.zero_grad()
        loss.backward()
        grad_norm = nn.utils.clip_grad_norm_(student_net.parameters(), max_norm=10)
        optimizer.step()
        scheduler.step()

        # Compute the accuracy for current batch.
        acc = (logits.argmax(dim=-1) == labels.to(device)).float().mean()

        # Record the loss and accuracy.
        train_loss.append(loss.item())
        train_accs.append(acc)

    # The average loss and accuracy of the training set is the average of the recorded values.
    train_loss = sum(train_loss) / len(train_loss)
    train_acc = sum(train_accs) / len(train_accs)

    # Print the information.
    print(f"[ Train | {epoch + 1:03d}/{n_epochs:03d} ] loss = {train_loss:.5f}, acc = {train_acc:.5f}")


    # ---------- Validation ----------
    # Make sure the model is in eval mode so that some modules like dropout are disabled and work normally.
    student_net.eval()

    # These are used to record information in validation.
    valid_loss = []
    valid_accs = []

    # Iterate the validation set by batches.
    for batch in tqdm(valid_loader):
        imgs, labels = batch
        with torch.no_grad():
            logits = student_net(imgs.to(device))
            soft_labels = teacher_net(imgs.to(device))
        # We can still compute the loss (but not the gradient).
        loss = loss_fn_kd(logits, labels.to(device), soft_labels)
        acc = (logits.argmax(dim=-1) == labels.to(device)).float().detach().cpu().view(-1).numpy()
        valid_loss.append(loss.item())
        valid_accs += list(acc)

    valid_loss = sum(valid_loss) / len(valid_loss)
    valid_acc = sum(valid_accs) / len(valid_accs)
    if valid_acc > best_acc:
        best_acc = valid_acc
        torch.save(student_net.state_dict(), './model.ckpt')

    # Print the information.
    print(f"[ Valid | {epoch + 1:03d}/{n_epochs:03d} ] loss = {valid_loss:.5f}, acc = {valid_acc:.5f}")

  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 001/100 ] loss = 8.71160, acc = 0.30702


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 001/100 ] loss = 8.02428, acc = 0.28939


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 002/100 ] loss = 7.32956, acc = 0.41741


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 002/100 ] loss = 6.72775, acc = 0.39091


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 003/100 ] loss = 6.55685, acc = 0.47555


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 003/100 ] loss = 5.97114, acc = 0.44394


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 004/100 ] loss = 6.00362, acc = 0.52121


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 004/100 ] loss = 5.42141, acc = 0.49091


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 005/100 ] loss = 5.57041, acc = 0.54809


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 005/100 ] loss = 5.10020, acc = 0.51818


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 006/100 ] loss = 5.22242, acc = 0.58553


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 006/100 ] loss = 5.40776, acc = 0.51667


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 007/100 ] loss = 5.02576, acc = 0.59720


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 007/100 ] loss = 5.09305, acc = 0.53333


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 008/100 ] loss = 4.81293, acc = 0.60735


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 008/100 ] loss = 4.74850, acc = 0.55303


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 009/100 ] loss = 4.67981, acc = 0.62023


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 009/100 ] loss = 4.41455, acc = 0.58485


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 010/100 ] loss = 4.55275, acc = 0.62896


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 010/100 ] loss = 4.25525, acc = 0.55606


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 011/100 ] loss = 4.39651, acc = 0.63657


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 011/100 ] loss = 4.26436, acc = 0.59545


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 012/100 ] loss = 4.20979, acc = 0.65300


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 012/100 ] loss = 4.33700, acc = 0.58030


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 013/100 ] loss = 4.11062, acc = 0.66437


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 013/100 ] loss = 4.29796, acc = 0.59545


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 014/100 ] loss = 4.03045, acc = 0.66700


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 014/100 ] loss = 4.34192, acc = 0.59848


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 015/100 ] loss = 3.95090, acc = 0.67411


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 015/100 ] loss = 4.28417, acc = 0.60909


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 016/100 ] loss = 3.91547, acc = 0.67806


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 016/100 ] loss = 3.78401, acc = 0.61667


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 017/100 ] loss = 3.84321, acc = 0.68405


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 017/100 ] loss = 3.66147, acc = 0.62727


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 018/100 ] loss = 3.72530, acc = 0.69369


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 018/100 ] loss = 3.64496, acc = 0.62879


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 019/100 ] loss = 3.63737, acc = 0.69795


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 019/100 ] loss = 3.82702, acc = 0.62273


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 020/100 ] loss = 3.54994, acc = 0.70769


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 020/100 ] loss = 3.91137, acc = 0.62273


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 021/100 ] loss = 3.49982, acc = 0.71084


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 021/100 ] loss = 4.40207, acc = 0.59242


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 022/100 ] loss = 3.52331, acc = 0.70373


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 022/100 ] loss = 3.60841, acc = 0.65000


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 023/100 ] loss = 3.46543, acc = 0.71439


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 023/100 ] loss = 3.27534, acc = 0.69091


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 024/100 ] loss = 3.37396, acc = 0.71642


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 024/100 ] loss = 3.22186, acc = 0.67879


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 025/100 ] loss = 3.28457, acc = 0.72149


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 025/100 ] loss = 3.37241, acc = 0.65909


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 026/100 ] loss = 3.27422, acc = 0.73082


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 026/100 ] loss = 3.77492, acc = 0.66061


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 027/100 ] loss = 3.18499, acc = 0.74290


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 027/100 ] loss = 3.60363, acc = 0.67273


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 028/100 ] loss = 3.15638, acc = 0.73478


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 028/100 ] loss = 3.45343, acc = 0.67424


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 029/100 ] loss = 3.11745, acc = 0.73894


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 029/100 ] loss = 3.13153, acc = 0.69242


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 030/100 ] loss = 3.14612, acc = 0.74077


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 030/100 ] loss = 3.06847, acc = 0.69242


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 031/100 ] loss = 3.08678, acc = 0.74280


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 031/100 ] loss = 3.00800, acc = 0.70152


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 032/100 ] loss = 3.04926, acc = 0.74564


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 032/100 ] loss = 3.15879, acc = 0.70303


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 033/100 ] loss = 2.98964, acc = 0.74959


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 033/100 ] loss = 3.37763, acc = 0.67727


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 034/100 ] loss = 2.94951, acc = 0.75964


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 034/100 ] loss = 3.67629, acc = 0.66212


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 035/100 ] loss = 2.91218, acc = 0.76481


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 035/100 ] loss = 3.39602, acc = 0.69545


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 036/100 ] loss = 2.93141, acc = 0.76096


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 036/100 ] loss = 3.05906, acc = 0.68636


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 037/100 ] loss = 2.87047, acc = 0.76319


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 037/100 ] loss = 2.85946, acc = 0.70909


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 038/100 ] loss = 2.88373, acc = 0.75802


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 038/100 ] loss = 2.87720, acc = 0.72273


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 039/100 ] loss = 2.82731, acc = 0.76360


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 039/100 ] loss = 3.11353, acc = 0.69848


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 040/100 ] loss = 2.78558, acc = 0.77212


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 040/100 ] loss = 3.39950, acc = 0.68636


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 041/100 ] loss = 2.75489, acc = 0.77577


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 041/100 ] loss = 3.12668, acc = 0.68485


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 042/100 ] loss = 2.76999, acc = 0.77121


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 042/100 ] loss = 2.99310, acc = 0.69848


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 043/100 ] loss = 2.74296, acc = 0.77608


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 043/100 ] loss = 2.82096, acc = 0.71818


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 044/100 ] loss = 2.74282, acc = 0.77699


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 044/100 ] loss = 2.66437, acc = 0.72879


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 045/100 ] loss = 2.68597, acc = 0.78247


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 045/100 ] loss = 2.77734, acc = 0.71212


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 046/100 ] loss = 2.68412, acc = 0.77973


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 046/100 ] loss = 2.77785, acc = 0.73939


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 047/100 ] loss = 2.58091, acc = 0.79038


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 047/100 ] loss = 3.16213, acc = 0.69242


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 048/100 ] loss = 2.60025, acc = 0.78967


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 048/100 ] loss = 3.28242, acc = 0.67121


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 049/100 ] loss = 2.58118, acc = 0.78886


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 049/100 ] loss = 2.90013, acc = 0.71667


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 050/100 ] loss = 2.60606, acc = 0.78632


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 050/100 ] loss = 2.65456, acc = 0.72424


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 051/100 ] loss = 2.60817, acc = 0.78531


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 051/100 ] loss = 2.63127, acc = 0.73333


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 052/100 ] loss = 2.58308, acc = 0.78764


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 052/100 ] loss = 2.70335, acc = 0.74242


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 053/100 ] loss = 2.53160, acc = 0.79119


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 053/100 ] loss = 2.79620, acc = 0.71212


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 054/100 ] loss = 2.50862, acc = 0.79373


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 054/100 ] loss = 2.94270, acc = 0.70909


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 055/100 ] loss = 2.47022, acc = 0.79576


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 055/100 ] loss = 3.09900, acc = 0.71970


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 056/100 ] loss = 2.47193, acc = 0.80083


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 056/100 ] loss = 2.99630, acc = 0.70909


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 057/100 ] loss = 2.50563, acc = 0.80103


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 057/100 ] loss = 2.70113, acc = 0.73030


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 058/100 ] loss = 2.51895, acc = 0.79282


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 058/100 ] loss = 2.59437, acc = 0.74848


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 059/100 ] loss = 2.47906, acc = 0.79566


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 059/100 ] loss = 2.57894, acc = 0.74242


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 060/100 ] loss = 2.44532, acc = 0.80175


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 060/100 ] loss = 2.94732, acc = 0.72273


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 061/100 ] loss = 2.45785, acc = 0.80712


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 061/100 ] loss = 2.96063, acc = 0.71667


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 062/100 ] loss = 2.36864, acc = 0.81108


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 062/100 ] loss = 2.91213, acc = 0.71667


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 063/100 ] loss = 2.38026, acc = 0.80540


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 063/100 ] loss = 2.68239, acc = 0.73485


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 064/100 ] loss = 2.38950, acc = 0.80357


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 064/100 ] loss = 2.53923, acc = 0.72273


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 065/100 ] loss = 2.39584, acc = 0.80651


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 065/100 ] loss = 2.54028, acc = 0.75455


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 066/100 ] loss = 2.35846, acc = 0.81088


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 066/100 ] loss = 2.55414, acc = 0.73636


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 067/100 ] loss = 2.34776, acc = 0.81179


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 067/100 ] loss = 2.68702, acc = 0.72727


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 068/100 ] loss = 2.30862, acc = 0.81605


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 068/100 ] loss = 3.14793, acc = 0.72879


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 069/100 ] loss = 2.29418, acc = 0.81686


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 069/100 ] loss = 2.78602, acc = 0.72424


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 070/100 ] loss = 2.31193, acc = 0.81280


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 070/100 ] loss = 2.72670, acc = 0.74091


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 071/100 ] loss = 2.32374, acc = 0.81311


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 071/100 ] loss = 2.55472, acc = 0.75152


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 072/100 ] loss = 2.33233, acc = 0.81047


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 072/100 ] loss = 2.48002, acc = 0.74091


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 073/100 ] loss = 2.30560, acc = 0.80996


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 073/100 ] loss = 2.57374, acc = 0.75303


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 074/100 ] loss = 2.25547, acc = 0.81351


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 074/100 ] loss = 2.74001, acc = 0.73182


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 075/100 ] loss = 2.23377, acc = 0.82305


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 075/100 ] loss = 2.87876, acc = 0.73636


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 076/100 ] loss = 2.23945, acc = 0.82336


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 076/100 ] loss = 2.84690, acc = 0.72727


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 077/100 ] loss = 2.27085, acc = 0.81778


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 077/100 ] loss = 2.81868, acc = 0.71818


  0%|          | 0/154 [00:00<?, ?it/s]

[ Train | 078/100 ] loss = 2.26074, acc = 0.81625


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 078/100 ] loss = 2.48177, acc = 0.73485


  0%|          | 0/154 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [23]:
test_model = StudentNet().to(device)
test_model.load_state_dict(torch.load('./model.ckpt'))
test_model.eval()

# Initialize a list to store the predictions.
predictions = []

# Iterate the testing set by batches.
for batch in tqdm(test_loader):
    imgs, labels = batch
    with torch.no_grad():
        logits = test_model(imgs.to(device))

    # Take the class with greatest logit as prediction and record it.
    predictions.extend(logits.argmax(dim=-1).cpu().numpy().tolist())

  0%|          | 0/53 [00:00<?, ?it/s]

In [24]:
with open("predict.csv", "w") as f:
    f.write("Id,Category\n")
    for i, pred in  enumerate(predictions):
         f.write(f"{i},{pred}\n")