In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torchvision.transforms as transforms
import torchvision.models as models

from torch.utils.data import ConcatDataset, DataLoader, Subset
from torchvision.datasets import DatasetFolder
from PIL import Image

from tqdm.auto import tqdm

In [2]:
# Training pipeline: resize to 142x142, then augment with a random horizontal
# flip, a random rotation of up to 15 degrees and a random 128x128 crop.
train_tfm = transforms.Compose(
    [
        transforms.Resize(size=(142, 142)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=15),
        transforms.RandomCrop(size=128),
        transforms.ToTensor(),
    ]
)

# Evaluation pipeline: same resize, but a deterministic central 128x128 crop.
test_tfm = transforms.Compose(
    [
        transforms.Resize(size=(142, 142)),
        transforms.CenterCrop(size=128),
        transforms.ToTensor(),
    ]
)

In [3]:
def pil_rgb_loader(path):
    """Open the image at `path` and return it as a fully loaded RGB PIL image.

    Same approach as torchvision's default PIL loader: the file is opened in a
    `with` block so the handle is closed before returning, and `.convert('RGB')`
    guarantees a 3-channel image even for grayscale/CMYK/paletted JPEGs, so
    ToTensor() always yields a 3xHxW tensor. Being a named module-level
    function (unlike a lambda) it can also be pickled if num_workers > 0.
    """
    with open(path, 'rb') as fh:
        img = Image.open(fh)
        return img.convert('RGB')


labeled_set = DatasetFolder('./food-11/training/labeled', loader=pil_rgb_loader, extensions='jpg', transform=train_tfm)
unlabeled_set = DatasetFolder('./food-11/training/unlabeled', loader=pil_rgb_loader, extensions='jpg', transform=train_tfm)
valid_set = DatasetFolder('./food-11/validation', loader=pil_rgb_loader, extensions='jpg', transform=test_tfm)
test_set = DatasetFolder('./food-11/testing', loader=pil_rgb_loader, extensions='jpg', transform=test_tfm)

batch_size = 64

train_loader = DataLoader(labeled_set, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True)
# Validation metrics do not depend on sample order (the model is evaluated in
# eval mode), so iterate deterministically.
valid_loader = DataLoader(valid_set, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)

In [4]:
def dwpw_conv(in_chs, out_chs, kernel_size, stride, padding):
    """Build a depthwise-separable convolution block followed by 2x2 max pooling.

    Layers, in order:
      0. depthwise Conv2d (groups=in_chs: each input channel convolved on its own)
      1. BatchNorm2d over the in_chs channels
      2. ReLU
      3. 1x1 pointwise Conv2d mixing channels in_chs -> out_chs
      4. MaxPool2d(2)
    There is no normalization or activation after the pointwise conv.

    Returns:
        nn.Sequential containing the five layers above.
    """
    layers = [
        nn.Conv2d(in_chs, in_chs, kernel_size, stride, padding, groups=in_chs),  # depthwise
        nn.BatchNorm2d(in_chs),
        nn.ReLU(),
        nn.Conv2d(in_chs, out_chs, 1),  # pointwise
        nn.MaxPool2d(2),
    ]
    return nn.Sequential(*layers)

In [5]:
class StudentNet(nn.Module):
    """Compact CNN student producing 11 class logits per image.

    Structure: a plain 3x3 conv stem (64 ch) with BN/ReLU/max-pool, two
    depthwise-separable blocks from `dwpw_conv` (64->128->256 ch), a final
    depthwise 3x3 + pointwise 1x1 conv down to 150 channels, global average
    pooling to 1x1, and a 150->64->11 MLP head. The module layout (and so the
    state_dict keys) is kept stable so saved checkpoints keep loading.
    """

    def __init__(self):
        super(StudentNet, self).__init__()
        self.cnn = nn.Sequential(
            nn.Sequential(
                nn.Conv2d(3, 64, 3, 1, 0),
                nn.BatchNorm2d(64),
                nn.ReLU(),
                nn.MaxPool2d(2),
            ),

            dwpw_conv(64, 128, 3, 1, 0),
            dwpw_conv(128, 256, 3, 1, 0),

            nn.Sequential(
                nn.Conv2d(256, 256, 3, 1, 0, groups=256),
                nn.BatchNorm2d(256),
                nn.ReLU(),
                nn.Conv2d(256, 150, 1),
            ),

            nn.AdaptiveAvgPool2d((1, 1)),
        )

        self.fc = nn.Sequential(
            nn.Linear(150, 64),
            nn.ReLU(),
            nn.Linear(64, 11),
        )

    def forward(self, x):
        """Return logits of shape (batch, 11) for an image batch `x` of shape (batch, 3, H, W)."""
        x = self.cnn(x)
        # (batch, 150, 1, 1) -> (batch, 150). Unlike x.squeeze(), flatten keeps the
        # batch dimension when batch == 1 (e.g. a partial final batch), so the
        # output stays 2-D and argmax(dim=-1) / cross_entropy behave correctly.
        out = self.fc(torch.flatten(x, 1))
        return out

In [6]:
# Layer-by-layer output shapes and parameter counts of the student, computed on
# the CPU for a single 3x128x128 input (the crop size used by the transforms).
from torchsummary import summary

summary(StudentNet(), input_size=(3, 128, 128), device='cpu')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 126, 126]           1,792
       BatchNorm2d-2         [-1, 64, 126, 126]             128
              ReLU-3         [-1, 64, 126, 126]               0
         MaxPool2d-4           [-1, 64, 63, 63]               0
            Conv2d-5           [-1, 64, 61, 61]             640
       BatchNorm2d-6           [-1, 64, 61, 61]             128
              ReLU-7           [-1, 64, 61, 61]               0
            Conv2d-8          [-1, 128, 61, 61]           8,320
         MaxPool2d-9          [-1, 128, 30, 30]               0
           Conv2d-10          [-1, 128, 28, 28]           1,280
      BatchNorm2d-11          [-1, 128, 28, 28]             256
             ReLU-12          [-1, 128, 28, 28]               0
           Conv2d-13          [-1, 256, 28, 28]          33,024
        MaxPool2d-14          [-1, 256,

In [11]:
def loss_fn_kd(outputs, labels, teacher_outputs, alpha=0.5, T=20):
    """Knowledge-distillation loss: weighted mix of hard-label CE and soft-target KL.

    Args:
        outputs: student logits, classes along the last dimension.
        labels: integer class targets for `outputs`.
        teacher_outputs: teacher logits, same shape as `outputs`.
        alpha: weight of the soft (distillation) term; the hard CE term gets 1 - alpha.
        T: softmax temperature applied to both student and teacher logits.

    Returns:
        Scalar tensor  alpha * T^2 * KL(teacher_T || student_T) + (1 - alpha) * CE(outputs, labels),
        where the KL uses 'batchmean' reduction. The T^2 factor keeps the soft
        term's gradient scale comparable across temperatures.
    """
    student_log_probs = F.log_softmax(outputs / T, dim=-1)
    teacher_probs = F.softmax(teacher_outputs / T, dim=-1)
    distill_term = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean')

    weighted_soft = distill_term * alpha * T * T
    weighted_hard = (1.0 - alpha) * F.cross_entropy(outputs, labels)
    return weighted_soft + weighted_hard

In [8]:
# Load the pretrained teacher, stored as a whole nn.Module (not a state_dict).
# map_location='cpu' lets a checkpoint written from a GPU load on a CPU-only
# machine; the model is moved to the training device afterwards.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code —
# only load trusted checkpoints. On newer PyTorch releases where weights_only
# defaults to True, loading a full module may also need weights_only=False —
# confirm against the installed version.
teacher_net = torch.load('./teacher_net.ckpt', map_location='cpu')
teacher_net.eval()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [9]:
# Use the GPU when one is visible to PyTorch, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# A freshly initialised student, and the loaded teacher on the same device.
student_net = StudentNet().to(device)
teacher_net = teacher_net.to(device)

# When True, the unlabeled split is pseudo-labelled by the teacher and added to training.
do_semi = True

def get_pseudo_labels(dataset, model):
    """Relabel every sample of `dataset` in place with `model`'s top-1 prediction.

    `model` is put in eval mode and run over `dataset` in its stored order
    (shuffle=False), using whatever transform the dataset currently has. Each
    `(path, old_label)` entry of `dataset.samples` becomes `(path, predicted_class)`,
    and `dataset.targets` (when the dataset has one) is rewritten to match, so it
    no longer holds the stale original labels. No confidence filtering is applied:
    every sample keeps a pseudo-label.

    Args:
        dataset: DatasetFolder-style dataset with a mutable `samples` list of
            (path, label) pairs.
        model: classifier whose output's last dimension indexes classes.

    Returns:
        The same `dataset` object, now carrying pseudo-labels.
    """
    model.eval()

    # shuffle=False keeps prediction order aligned with dataset.samples.
    loader = DataLoader(dataset, batch_size=batch_size * 3, shuffle=False)

    batch_preds = []
    with torch.no_grad():
        for imgs, _ in tqdm(loader):
            logits = model(imgs.to(device))
            batch_preds.append(logits.argmax(dim=-1).cpu())

    # Empty dataset: nothing to relabel (torch.cat of an empty list would raise).
    if not batch_preds:
        return dataset

    pseudo_labels = torch.cat(batch_preds).tolist()

    # DatasetFolder keeps labels both in `samples` and in `targets`; update both.
    targets = getattr(dataset, 'targets', None)
    for idx, ((path, _), label) in enumerate(zip(dataset.samples, pseudo_labels)):
        dataset.samples[idx] = (path, label)
        if targets is not None:
            targets[idx] = label

    return dataset

if do_semi:
    # Predict pseudo-labels on deterministic eval views (test_tfm): labelling
    # randomly flipped/rotated/cropped views makes the teacher's labels noisier.
    # The dataset's own (augmenting) transform is restored afterwards so the
    # student still trains on augmented versions of these images.
    unlabeled_train_tfm = unlabeled_set.transform
    unlabeled_set.transform = test_tfm
    try:
        dataset = get_pseudo_labels(unlabeled_set, teacher_net)
    finally:
        unlabeled_set.transform = unlabeled_train_tfm

    train_set = ConcatDataset([labeled_set, dataset])
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, pin_memory=True, drop_last=True)

Widget Javascript not detected.  It may not be installed or enabled properly.





In [12]:
optimizer = torch.optim.Adam(student_net.parameters(), lr=0.001, weight_decay=1e-4)

n_epoch = 100
best_valid_acc = 0.0

for epoch in range(n_epoch):
    # ---------- Training: distil the frozen teacher into the student ----------
    student_net.train()
    train_acc, train_loss = [], []

    for imgs, labels in tqdm(train_loader):
        imgs, labels = imgs.to(device), labels.to(device)
        logits = student_net(imgs)
        # Teacher logits are only targets; no autograd graph is needed for them.
        with torch.no_grad():
            teacher_logits = teacher_net(imgs)

        loss = loss_fn_kd(logits, labels, teacher_logits)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss.append(loss.item())
        train_acc.append((logits.argmax(dim=-1) == labels).float().mean().item())

    # Means of the per-batch values.
    train_loss_mean = sum(train_loss) / len(train_loss)
    train_acc_mean = sum(train_acc) / len(train_acc)

    print(f"[ Train | {epoch + 1:03d}/{n_epoch:03d} ] loss = {train_loss_mean:.5f}, acc = {train_acc_mean:.5f}")

    # ---------- Validation ----------
    student_net.eval()
    valid_acc, valid_loss = [], []

    # The student's forward pass must run under no_grad too: otherwise autograd
    # records a graph holding the student's intermediate activations that is
    # never used, needlessly inflating GPU memory during validation.
    with torch.no_grad():
        for imgs, labels in tqdm(valid_loader):
            imgs, labels = imgs.to(device), labels.to(device)
            logits = student_net(imgs)
            teacher_logits = teacher_net(imgs)

            loss = loss_fn_kd(logits, labels, teacher_logits)
            valid_loss.append(loss.item())
            # Per-sample correctness, so accuracy is exact even with a partial last batch.
            valid_acc.extend((logits.argmax(dim=-1) == labels).float().cpu().tolist())

    valid_loss_mean = sum(valid_loss) / len(valid_loss)
    valid_acc_mean = sum(valid_acc) / len(valid_acc)

    print(f"[ Valid | {epoch + 1:03d}/{n_epoch:03d} ] loss = {valid_loss_mean:.5f}, acc = {valid_acc_mean:.5f}")

    if valid_acc_mean > best_valid_acc:
        # Report the 1-based epoch number, consistent with the progress lines above.
        print(f'Validation accuracy improve from {best_valid_acc:.5f} to {valid_acc_mean:.5f} at epoch {epoch + 1}')
        best_valid_acc = valid_acc_mean
        torch.save(student_net.state_dict(), './student_net_best.ckpt')

Widget Javascript not detected.  It may not be installed or enabled properly.


RuntimeError: CUDA out of memory. Tried to allocate 250.00 MiB (GPU 0; 4.00 GiB total capacity; 1.68 GiB already allocated; 0 bytes free; 1.74 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:
# Rebuild the student and load the best checkpoint saved during training.
# map_location places the saved tensors on the current device, so a checkpoint
# written on a GPU also loads on a CPU-only machine.
student_net = StudentNet().to(device)
student_net.load_state_dict(torch.load('./student_net_best.ckpt', map_location=device))

student_net.eval()
predictions = []

# test_loader is built with shuffle=False, so predictions follow test_set order.
with torch.no_grad():
    for imgs, _ in tqdm(test_loader):
        logits = student_net(imgs.to(device))
        predictions.extend(logits.argmax(dim=-1).cpu().tolist())

In [None]:
# Submission file: a header row, then one "<index>,<predicted class>" line per
# test image, in prediction order.
with open('predict.csv', 'w') as f:
    f.write("Id,Category\n")
    f.writelines(f'{i},{pred}\n' for i, pred in enumerate(predictions))