# CNN

kaggle: https://www.kaggle.com/competitions/ml2021spring-hw3/overview

In [39]:
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, ConcatDataset, Dataset, Subset
from PIL import Image
# # 合并数据集用
import torchvision.transforms as transforms
from torchvision.datasets import DatasetFolder
from torchvision import models

from tqdm.auto import tqdm

## Dataset, DataLoader, Transforms

In [40]:
train_tfm = transforms.Compose([
    transforms.Resize((128, 128)),
    
    transforms.RandomHorizontalFlip(), # 隨機將圖片水平翻轉
    transforms.RandomRotation(15), # 隨機旋轉圖片
    transforms.AutoAugment(transforms.AutoAugmentPolicy.IMAGENET),
    
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),

])

test_tfm = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),

])

In [41]:
batch_size = 128

# 构建数据集
train_set = DatasetFolder('./food-11/training/labeled', loader=lambda x: Image.open(x), extensions='jpg', transform=train_tfm)
val_set = DatasetFolder('./food-11/validation', loader=lambda x: Image.open(x), extensions='jpg', transform=test_tfm)
unlabeled_set = DatasetFolder("./food-11/training/unlabeled", loader=lambda x: Image.open(x), extensions="jpg", transform=train_tfm)
test_set = DatasetFolder('./food-11/testing', loader=lambda x: Image.open(x), extensions='jpg', transform=test_tfm)

# 构建数据加载器
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)

## Model

In [42]:
class Classifier(nn.Module):
    def __init__(self):
        super(Classifier, self).__init__()

        # input image size: [3, 128, 128]
        self.model = models.resnet18(pretrained=False)
        self.model.fc = nn.Linear(512, 11)

    def forward(self, x):
        return self.model(x)

## Training

使用 semi-supervised learning 提高performance

In [43]:
class PseudoDataset(Dataset):
    def __init__(self, x, y):
        self.x = x
        self.y = y
    def __len__(self):
        return len(self.x)
    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]


def get_pseudo_label(dataset, model, threshold=0.65):
    # This functions generates pseudo-labels of a dataset using given model.
    # It returns an instance of DatasetFolder containing images whose prediction confidences exceed a given threshold.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    model.eval()
    softmax = nn.Softmax(dim=1)
    for i, batch in tqdm(data_loader):
        img, _ = batch
        with torch.no_grad():
            logit = model(img.to(device))
        probs = softmax(logit)  # size: batch_size x 11
        
        idx = []
        labels = []
        for j, x in enumerate(probs):
            if torch.max(x) > threshold:
                idx.append(i * batch + j)
                labels.append(int(torch.argmax(x)))

    model.train()
    dataset = PseudoDataset(Subset(dataset, idx), labels)
    return dataset

In [46]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = Classifier().to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
epochs = 50
do_semi = False

for epoch in range(epochs):
    if do_semi:
      # ---------for semi-supervised learning--------
      # Not finish
        pseudo_set = get_pseudo_label(train_set, model)
        concat_dataset = ConcatDataset([train_set, pseudo_set])
        train_loader = DataLoader(concat_dataset, batch_size=batch_size, shuffle=True)
      # ---------for semi-supervised learning--------


    model.train()
    train_loss = []
    train_acc = []

    for batch in tqdm(train_loader):
        imgs, labels = batch
        logits = model(imgs.to(device))
        loss = criterion(logits, labels.to(device))
        optimizer.zero_grad()
        loss.backward()
        # 梯度剪裁
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        acc = (logits.argmax(dim=-1) == labels.to(device)).float().mean()
        train_loss.append(loss.item())
        train_acc.append(acc)

    train_loss = sum(train_loss) / len(train_loss)
    train_acc = sum(train_acc) / len(train_acc)
    
    print(f"[ Train | {epoch + 1:03d}/{epochs:03d} ] loss = {train_loss:.5f}, acc = {train_acc:.5f}")


    model.eval()
    valid_loss = []
    valid_acc = []
    for batch in tqdm(valid_loader):
        imgs, labels = batch
        with torch.no_grad():
            logits = model(imgs.to(device))
            loss = criterion(logits, labels.to(device))
            acc = (logits.argmax(dim=-1) == labels.to(device)).float().mean()
            valid_loss.append(loss.item())
            valid_acc.append(acc)

    valid_loss = sum(valid_loss) / len(valid_loss)
    valid_acc = sum(valid_acc) / len(valid_acc)

    print(f"[ Valid | {epoch + 1:03d}/{epochs:03d} ] loss = {valid_loss:.5f}, acc = {valid_acc:.5f}")

  0%|          | 0/25 [00:00<?, ?it/s]

[ Train | 001/050 ] loss = 2.38299, acc = 0.16813


  0%|          | 0/6 [00:00<?, ?it/s]

[ Valid | 001/050 ] loss = 2.45667, acc = 0.13125


  0%|          | 0/25 [00:00<?, ?it/s]

[ Train | 002/050 ] loss = 2.23381, acc = 0.19406


  0%|          | 0/6 [00:00<?, ?it/s]

[ Valid | 002/050 ] loss = 2.59197, acc = 0.17188


  0%|          | 0/25 [00:00<?, ?it/s]

[ Train | 003/050 ] loss = 2.14473, acc = 0.23656


  0%|          | 0/6 [00:00<?, ?it/s]

[ Valid | 003/050 ] loss = 2.14699, acc = 0.21276


  0%|          | 0/25 [00:00<?, ?it/s]

[ Train | 004/050 ] loss = 2.09158, acc = 0.24469


  0%|          | 0/6 [00:00<?, ?it/s]

[ Valid | 004/050 ] loss = 2.91965, acc = 0.17682


  0%|          | 0/25 [00:00<?, ?it/s]

[ Train | 005/050 ] loss = 2.06704, acc = 0.27094


  0%|          | 0/6 [00:00<?, ?it/s]

[ Valid | 005/050 ] loss = 1.98041, acc = 0.33906


  0%|          | 0/25 [00:00<?, ?it/s]

[ Train | 006/050 ] loss = 2.03613, acc = 0.29094


  0%|          | 0/6 [00:00<?, ?it/s]

[ Valid | 006/050 ] loss = 2.08359, acc = 0.30313


  0%|          | 0/25 [00:00<?, ?it/s]

[ Train | 007/050 ] loss = 1.96241, acc = 0.30906


  0%|          | 0/6 [00:00<?, ?it/s]

[ Valid | 007/050 ] loss = 2.10511, acc = 0.30625


  0%|          | 0/25 [00:00<?, ?it/s]

[ Train | 008/050 ] loss = 1.94178, acc = 0.31812


  0%|          | 0/6 [00:00<?, ?it/s]

[ Valid | 008/050 ] loss = 1.89671, acc = 0.35078


  0%|          | 0/25 [00:00<?, ?it/s]

[ Train | 009/050 ] loss = 1.91333, acc = 0.32031


  0%|          | 0/6 [00:00<?, ?it/s]

[ Valid | 009/050 ] loss = 1.74689, acc = 0.40260


  0%|          | 0/25 [00:00<?, ?it/s]

[ Train | 010/050 ] loss = 1.90494, acc = 0.34219


  0%|          | 0/6 [00:00<?, ?it/s]

[ Valid | 010/050 ] loss = 1.86795, acc = 0.39792


  0%|          | 0/25 [00:00<?, ?it/s]

[ Train | 011/050 ] loss = 1.89063, acc = 0.33406


  0%|          | 0/6 [00:00<?, ?it/s]

[ Valid | 011/050 ] loss = 2.05966, acc = 0.38880


  0%|          | 0/25 [00:00<?, ?it/s]

[ Train | 012/050 ] loss = 1.85455, acc = 0.36031


  0%|          | 0/6 [00:00<?, ?it/s]

[ Valid | 012/050 ] loss = 2.05478, acc = 0.36380


  0%|          | 0/25 [00:00<?, ?it/s]

[ Train | 013/050 ] loss = 1.83202, acc = 0.34781


  0%|          | 0/6 [00:00<?, ?it/s]

[ Valid | 013/050 ] loss = 1.89437, acc = 0.36250


  0%|          | 0/25 [00:00<?, ?it/s]

[ Train | 014/050 ] loss = 1.77862, acc = 0.38719


  0%|          | 0/6 [00:00<?, ?it/s]

[ Valid | 014/050 ] loss = 1.74123, acc = 0.40885


  0%|          | 0/25 [00:00<?, ?it/s]

[ Train | 015/050 ] loss = 1.76677, acc = 0.38375


  0%|          | 0/6 [00:00<?, ?it/s]

[ Valid | 015/050 ] loss = 1.53450, acc = 0.47682


  0%|          | 0/25 [00:00<?, ?it/s]

[ Train | 016/050 ] loss = 1.71840, acc = 0.40500


  0%|          | 0/6 [00:00<?, ?it/s]

[ Valid | 016/050 ] loss = 1.78355, acc = 0.43932


  0%|          | 0/25 [00:00<?, ?it/s]

[ Train | 017/050 ] loss = 1.69121, acc = 0.41562


  0%|          | 0/6 [00:00<?, ?it/s]

[ Valid | 017/050 ] loss = 2.00163, acc = 0.32917


  0%|          | 0/25 [00:00<?, ?it/s]

[ Train | 018/050 ] loss = 1.69619, acc = 0.41781


  0%|          | 0/6 [00:00<?, ?it/s]

[ Valid | 018/050 ] loss = 1.83990, acc = 0.39219


  0%|          | 0/25 [00:00<?, ?it/s]

[ Train | 019/050 ] loss = 1.76681, acc = 0.39781


  0%|          | 0/6 [00:00<?, ?it/s]

[ Valid | 019/050 ] loss = 1.88053, acc = 0.38698


  0%|          | 0/25 [00:00<?, ?it/s]

[ Train | 020/050 ] loss = 1.70613, acc = 0.41250


  0%|          | 0/6 [00:00<?, ?it/s]

[ Valid | 020/050 ] loss = 1.64927, acc = 0.45495


  0%|          | 0/25 [00:00<?, ?it/s]

[ Train | 021/050 ] loss = 1.68747, acc = 0.42531


  0%|          | 0/6 [00:00<?, ?it/s]

[ Valid | 021/050 ] loss = 1.35917, acc = 0.54766


  0%|          | 0/25 [00:00<?, ?it/s]

[ Train | 022/050 ] loss = 1.61936, acc = 0.45125


  0%|          | 0/6 [00:00<?, ?it/s]

[ Valid | 022/050 ] loss = 1.46985, acc = 0.53177


  0%|          | 0/25 [00:00<?, ?it/s]

[ Train | 023/050 ] loss = 1.58219, acc = 0.44625


  0%|          | 0/6 [00:00<?, ?it/s]

[ Valid | 023/050 ] loss = 1.60972, acc = 0.46068


  0%|          | 0/25 [00:00<?, ?it/s]

[ Train | 024/050 ] loss = 1.54299, acc = 0.48344


  0%|          | 0/6 [00:00<?, ?it/s]

[ Valid | 024/050 ] loss = 1.68964, acc = 0.44896


  0%|          | 0/25 [00:00<?, ?it/s]

[ Train | 025/050 ] loss = 1.61590, acc = 0.44625


  0%|          | 0/6 [00:00<?, ?it/s]

[ Valid | 025/050 ] loss = 1.77521, acc = 0.44375


  0%|          | 0/25 [00:00<?, ?it/s]

[ Train | 026/050 ] loss = 1.54954, acc = 0.47594


  0%|          | 0/6 [00:00<?, ?it/s]

[ Valid | 026/050 ] loss = 1.40954, acc = 0.54974


  0%|          | 0/25 [00:00<?, ?it/s]

[ Train | 027/050 ] loss = 1.51757, acc = 0.47875


  0%|          | 0/6 [00:00<?, ?it/s]

[ Valid | 027/050 ] loss = 1.39864, acc = 0.55443


  0%|          | 0/25 [00:00<?, ?it/s]

[ Train | 028/050 ] loss = 1.52356, acc = 0.47875


  0%|          | 0/6 [00:00<?, ?it/s]

[ Valid | 028/050 ] loss = 1.50329, acc = 0.49375


  0%|          | 0/25 [00:00<?, ?it/s]

[ Train | 029/050 ] loss = 1.52104, acc = 0.47812


  0%|          | 0/6 [00:00<?, ?it/s]

[ Valid | 029/050 ] loss = 1.45491, acc = 0.50990


  0%|          | 0/25 [00:00<?, ?it/s]

[ Train | 030/050 ] loss = 1.42095, acc = 0.52312


  0%|          | 0/6 [00:00<?, ?it/s]

[ Valid | 030/050 ] loss = 1.31656, acc = 0.55964


  0%|          | 0/25 [00:00<?, ?it/s]

[ Train | 031/050 ] loss = 1.47068, acc = 0.49281


  0%|          | 0/6 [00:00<?, ?it/s]

[ Valid | 031/050 ] loss = 1.59765, acc = 0.49375


  0%|          | 0/25 [00:00<?, ?it/s]

[ Train | 032/050 ] loss = 1.42298, acc = 0.51313


  0%|          | 0/6 [00:00<?, ?it/s]

[ Valid | 032/050 ] loss = 1.56875, acc = 0.49193


  0%|          | 0/25 [00:00<?, ?it/s]

[ Train | 033/050 ] loss = 1.41129, acc = 0.52875


  0%|          | 0/6 [00:00<?, ?it/s]

[ Valid | 033/050 ] loss = 1.31634, acc = 0.56354


  0%|          | 0/25 [00:00<?, ?it/s]

[ Train | 034/050 ] loss = 1.36498, acc = 0.53969


  0%|          | 0/6 [00:00<?, ?it/s]

[ Valid | 034/050 ] loss = 1.31211, acc = 0.56354


  0%|          | 0/25 [00:00<?, ?it/s]

[ Train | 035/050 ] loss = 1.32918, acc = 0.54531


  0%|          | 0/6 [00:00<?, ?it/s]

[ Valid | 035/050 ] loss = 1.31301, acc = 0.58255


  0%|          | 0/25 [00:00<?, ?it/s]

[ Train | 036/050 ] loss = 1.33957, acc = 0.55531


  0%|          | 0/6 [00:00<?, ?it/s]

[ Valid | 036/050 ] loss = 1.31520, acc = 0.57135


  0%|          | 0/25 [00:00<?, ?it/s]

[ Train | 037/050 ] loss = 1.36529, acc = 0.53438


  0%|          | 0/6 [00:00<?, ?it/s]

[ Valid | 037/050 ] loss = 1.49125, acc = 0.51667


  0%|          | 0/25 [00:00<?, ?it/s]

[ Train | 038/050 ] loss = 1.32257, acc = 0.56563


  0%|          | 0/6 [00:00<?, ?it/s]

[ Valid | 038/050 ] loss = 1.22026, acc = 0.60365


  0%|          | 0/25 [00:00<?, ?it/s]

[ Train | 039/050 ] loss = 1.36322, acc = 0.54094


  0%|          | 0/6 [00:00<?, ?it/s]

[ Valid | 039/050 ] loss = 1.26494, acc = 0.60182


  0%|          | 0/25 [00:00<?, ?it/s]

[ Train | 040/050 ] loss = 1.27883, acc = 0.56344


  0%|          | 0/6 [00:00<?, ?it/s]

[ Valid | 040/050 ] loss = 1.39430, acc = 0.54375


  0%|          | 0/25 [00:00<?, ?it/s]

[ Train | 041/050 ] loss = 1.19308, acc = 0.60844


  0%|          | 0/6 [00:00<?, ?it/s]

[ Valid | 041/050 ] loss = 1.26931, acc = 0.58385


  0%|          | 0/25 [00:00<?, ?it/s]

[ Train | 042/050 ] loss = 1.24563, acc = 0.57000


  0%|          | 0/6 [00:00<?, ?it/s]

[ Valid | 042/050 ] loss = 1.28247, acc = 0.57214


  0%|          | 0/25 [00:00<?, ?it/s]

[ Train | 043/050 ] loss = 1.21741, acc = 0.59219


  0%|          | 0/6 [00:00<?, ?it/s]

[ Valid | 043/050 ] loss = 1.36916, acc = 0.57031


  0%|          | 0/25 [00:00<?, ?it/s]

[ Train | 044/050 ] loss = 1.17923, acc = 0.60062


  0%|          | 0/6 [00:00<?, ?it/s]

[ Valid | 044/050 ] loss = 1.40367, acc = 0.55599


  0%|          | 0/25 [00:00<?, ?it/s]

[ Train | 045/050 ] loss = 1.19578, acc = 0.60438


  0%|          | 0/6 [00:00<?, ?it/s]

[ Valid | 045/050 ] loss = 1.89416, acc = 0.48047


  0%|          | 0/25 [00:00<?, ?it/s]

[ Train | 046/050 ] loss = 1.11666, acc = 0.62500


  0%|          | 0/6 [00:00<?, ?it/s]

[ Valid | 046/050 ] loss = 1.24931, acc = 0.62448


  0%|          | 0/25 [00:00<?, ?it/s]

[ Train | 047/050 ] loss = 1.14374, acc = 0.61344


  0%|          | 0/6 [00:00<?, ?it/s]

[ Valid | 047/050 ] loss = 1.33721, acc = 0.55599


  0%|          | 0/25 [00:00<?, ?it/s]

[ Train | 048/050 ] loss = 1.08923, acc = 0.63563


  0%|          | 0/6 [00:00<?, ?it/s]

[ Valid | 048/050 ] loss = 1.24974, acc = 0.57682


  0%|          | 0/25 [00:00<?, ?it/s]

[ Train | 049/050 ] loss = 1.00570, acc = 0.66062


  0%|          | 0/6 [00:00<?, ?it/s]

[ Valid | 049/050 ] loss = 1.24270, acc = 0.59375


  0%|          | 0/25 [00:00<?, ?it/s]

[ Train | 050/050 ] loss = 0.99852, acc = 0.66875


  0%|          | 0/6 [00:00<?, ?it/s]

[ Valid | 050/050 ] loss = 1.35193, acc = 0.59688


## Testing

In [47]:
model.eval()

predictions = []

for batch in tqdm(test_loader):
    # A batch consists of image data and corresponding labels.
    # But here the variable "labels" is useless since we do not have the ground-truth.
    # If printing out the labels, you will find that it is always 0.
    # This is because the wrapper (DatasetFolder) returns images and labels for each batch,
    # so we have to create fake labels to make it work normally.
    imgs, labels = batch

    with torch.no_grad():
        logits = model(imgs.to(device))

    predictions.extend(logits.argmax(dim=-1).cpu().numpy().tolist())

predictions[:10]

  0%|          | 0/27 [00:00<?, ?it/s]

[2, 10, 4, 9, 7, 3, 4, 10, 0, 2]

In [48]:
with open("predict.csv", "w") as f:
    f.write("Id,Category\n")

    for i, pred in  enumerate(predictions):
         f.write(f"{i},{pred}\n")

In [51]:
!kaggle competitions submit ml2021spring-hw3 -f ./predict.csv -m 'ResNet18 with no pre-train'

100% 22.1k/22.1k [00:01<00:00, 19.8kB/s]
Successfully submitted to ML2021spring - hw3