In [1]:
import os
from typing import Tuple, Sequence, Callable
import csv
import cv2
import numpy as np
import pandas as pd
from PIL import Image
import torch
import torch.nn.init as init
import torch.optim as optim
from torch import nn, Tensor
from torch.utils.data import Dataset, DataLoader
from torchsummary import summary

from torchvision import transforms
from torchvision.models import resnet50
import random

In [2]:
!nvidia-smi
print(torch.cuda.is_available())

Sat Feb 13 00:07:41 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  GeForce RTX 306...  Off  | 00000000:01:00.0 Off |                  N/A |
|  0%   33C    P0    46W / 220W |    222MiB /  7982MiB |      1%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [3]:
def seed_everything(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)  # type: ignore
    torch.backends.cudnn.deterministic = True  # type: ignore
    torch.backends.cudnn.benchmark = True  # type: ignore

In [4]:
seed_everything(42)

dataset_path = '/home/gwonil/datasets/mnist_multilabel/'
weight_path = '/home/gwonil/Mylab/Coding/Dacon_competition/multi-label_classification/weights'

## 1. 커스텀 데이터셋 만들기

In [5]:
def split_dataset(path):
    df = pd.read_csv(path)
    kfold = KFold(n_splits=5)
    for fold, (train, valid) in enumerate(kfold.split(df, df.index)):
        df.loc[valid, 'kfold'] = int(fold)
        
    df.to_csv(dataset_path + 'split_kfold.csv')

In [6]:
class MnistDataset(Dataset):
    def __init__(
        self,
        dir: os.PathLike,
        image_ids: os.PathLike,
        transforms: Sequence[Callable]
    ) -> None:
        self.dir = dir
        self.transforms = transforms

        self.labels = {}
        with open(image_ids, 'r') as f:
            reader = csv.reader(f)
            next(reader)
            for row in reader:
                self.labels[int(row[0])] = list(map(int, row[1:]))

        self.image_ids = list(self.labels.keys())

    def __len__(self) -> int:
        return len(self.image_ids)

    def __getitem__(self, index: int) -> Tuple[Tensor]:
        image_id = self.image_ids[index]
        image = Image.open(
            os.path.join(
                self.dir, f'{str(image_id).zfill(5)}.png')).convert('RGB')
        target = np.array(self.labels.get(image_id)).astype(np.float32)

        if self.transforms is not None:
            image = self.transforms(image)

        return image, target

## 2. 이미지 어그멘테이션

In [14]:
transforms_train = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    #transforms.RandomRotation(30, expand=False),
    transforms.ToTensor(),
    transforms.Normalize(
        [0.485, 0.456, 0.406],
        [0.229, 0.224, 0.225]
    )
])

transforms_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        [0.485, 0.456, 0.406],
        [0.229, 0.224, 0.225]
    )
])

In [15]:
trainset = MnistDataset(dataset_path + 'dirty_mnist_2nd', dataset_path + 'dirty_mnist_2nd_answer.csv', transforms_train)
testset = MnistDataset(dataset_path + 'test_dirty_mnist_2nd', dataset_path + 'sample_submission.csv', transforms_test)

train_loader = DataLoader(trainset, batch_size=32, num_workers=4)
test_loader = DataLoader(testset, batch_size=32, num_workers=4)

## 3. ResNet50 모형

In [16]:
class CNN(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.resnet = resnet50(pretrained=True)
        self.classifier = nn.Linear(1000, 26)
        
        #for m in self.modules():
        #    if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
        #        init.xavier_normal_(m.weight)

    def forward(self, x):
        x = self.resnet(x)
        x = self.classifier(x)

        return x

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = CNN().to(device)
print(summary(model, input_size=(1, 3, 256, 256), verbose=0))
print(device)

Layer (type:depth-idx)                   Param #
├─ResNet: 1-1                            --
|    └─Conv2d: 2-1                       9,408
|    └─BatchNorm2d: 2-2                  128
|    └─ReLU: 2-3                         --
|    └─MaxPool2d: 2-4                    --
|    └─Sequential: 2-5                   --
|    |    └─Bottleneck: 3-1              75,008
|    |    └─Bottleneck: 3-2              70,400
|    |    └─Bottleneck: 3-3              70,400
|    └─Sequential: 2-6                   --
|    |    └─Bottleneck: 3-4              379,392
|    |    └─Bottleneck: 3-5              280,064
|    |    └─Bottleneck: 3-6              280,064
|    |    └─Bottleneck: 3-7              280,064
|    └─Sequential: 2-7                   --
|    |    └─Bottleneck: 3-8              1,512,448
|    |    └─Bottleneck: 3-9              1,117,184
|    |    └─Bottleneck: 3-10             1,117,184
|    |    └─Bottleneck: 3-11             1,117,184
|    |    └─Bottleneck: 3-12             1,117,184


## 4. 학습하기

In [17]:
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MultiLabelSoftMarginLoss()

# loaded trained model
#model.load_state_dict(torch.load("model_sample.pth"))

num_epochs = 10
model.train()

for epoch in range(num_epochs):
    for i, (images, targets) in enumerate(train_loader):
        optimizer.zero_grad()

        images = images.to(device)
        targets = targets.to(device)

        outputs = model(images)
        loss = criterion(outputs, targets)

        loss.backward()
        optimizer.step()

        if (i+1) % 10 == 0:
            outputs = outputs > 0.5
            acc = (outputs == targets).float().mean()
            print(f'{epoch+1}[{i+1:4d}/{len(train_loader)}]: {loss.item():.5f}, {acc.item():.5f}')

1[  10/1563]: 0.73619, 0.52163
1[  20/1563]: 0.71763, 0.53966
1[  30/1563]: 0.70329, 0.54087
1[  40/1563]: 0.69653, 0.54207
1[  50/1563]: 0.68953, 0.53245
1[  60/1563]: 0.69399, 0.54447
1[  70/1563]: 0.69594, 0.53486
1[  80/1563]: 0.69710, 0.52885
1[  90/1563]: 0.69041, 0.54928
1[ 100/1563]: 0.69558, 0.54207
1[ 110/1563]: 0.69266, 0.54327
1[ 120/1563]: 0.69408, 0.52885
1[ 130/1563]: 0.69343, 0.54447
1[ 140/1563]: 0.69776, 0.53606
1[ 150/1563]: 0.69661, 0.53486
1[ 160/1563]: 0.69526, 0.54087
1[ 170/1563]: 0.69648, 0.53846
1[ 180/1563]: 0.69315, 0.52043
1[ 190/1563]: 0.69103, 0.54447
1[ 200/1563]: 0.69078, 0.54327
1[ 210/1563]: 0.69395, 0.53245
1[ 220/1563]: 0.69085, 0.55048
1[ 230/1563]: 0.68876, 0.54567
1[ 240/1563]: 0.69300, 0.53606
1[ 250/1563]: 0.69180, 0.53966
1[ 260/1563]: 0.69097, 0.53125
1[ 270/1563]: 0.69341, 0.52885
1[ 280/1563]: 0.69098, 0.54087
1[ 290/1563]: 0.69644, 0.53365
1[ 300/1563]: 0.68982, 0.52644
1[ 310/1563]: 0.68745, 0.53846
1[ 320/1563]: 0.68921, 0.54928
1[ 330/1

2[1100/1563]: 0.66971, 0.56611
2[1110/1563]: 0.67759, 0.54928
2[1120/1563]: 0.67762, 0.55529
2[1130/1563]: 0.66892, 0.55529
2[1140/1563]: 0.67407, 0.57212
2[1150/1563]: 0.66627, 0.55529
2[1160/1563]: 0.66733, 0.56370
2[1170/1563]: 0.66295, 0.57933
2[1180/1563]: 0.67584, 0.56490
2[1190/1563]: 0.66861, 0.55649
2[1200/1563]: 0.67021, 0.55048
2[1210/1563]: 0.66956, 0.57212
2[1220/1563]: 0.67011, 0.56490
2[1230/1563]: 0.67957, 0.55168
2[1240/1563]: 0.67160, 0.56370
2[1250/1563]: 0.67830, 0.55168
2[1260/1563]: 0.66566, 0.55769
2[1270/1563]: 0.67383, 0.53846
2[1280/1563]: 0.66467, 0.56130
2[1290/1563]: 0.66016, 0.56851
2[1300/1563]: 0.67480, 0.55168
2[1310/1563]: 0.66727, 0.56851
2[1320/1563]: 0.67004, 0.57091
2[1330/1563]: 0.66633, 0.55048
2[1340/1563]: 0.67363, 0.56370
2[1350/1563]: 0.66696, 0.55409
2[1360/1563]: 0.67006, 0.55529
2[1370/1563]: 0.65263, 0.56250
2[1380/1563]: 0.67560, 0.55649
2[1390/1563]: 0.66050, 0.53726
2[1400/1563]: 0.66797, 0.55889
2[1410/1563]: 0.66418, 0.57091
2[1420/1

4[ 630/1563]: 0.63231, 0.62019
4[ 640/1563]: 0.62336, 0.61538
4[ 650/1563]: 0.64470, 0.58654
4[ 660/1563]: 0.63257, 0.59856
4[ 670/1563]: 0.60459, 0.64663
4[ 680/1563]: 0.64136, 0.59856
4[ 690/1563]: 0.62951, 0.61538
4[ 700/1563]: 0.62418, 0.63101
4[ 710/1563]: 0.62892, 0.60697
4[ 720/1563]: 0.62684, 0.62740
4[ 730/1563]: 0.63491, 0.60216
4[ 740/1563]: 0.62053, 0.61298
4[ 750/1563]: 0.61102, 0.61779
4[ 760/1563]: 0.62039, 0.59375
4[ 770/1563]: 0.61534, 0.61418
4[ 780/1563]: 0.61718, 0.61899
4[ 790/1563]: 0.63505, 0.61058
4[ 800/1563]: 0.63914, 0.58534
4[ 810/1563]: 0.62689, 0.60457
4[ 820/1563]: 0.60742, 0.64062
4[ 830/1563]: 0.63314, 0.61659
4[ 840/1563]: 0.62490, 0.61899
4[ 850/1563]: 0.61673, 0.64062
4[ 860/1563]: 0.60227, 0.62740
4[ 870/1563]: 0.62570, 0.62380
4[ 880/1563]: 0.62790, 0.63341
4[ 890/1563]: 0.63395, 0.60096
4[ 900/1563]: 0.61623, 0.60938
4[ 910/1563]: 0.60719, 0.62861
4[ 920/1563]: 0.60819, 0.63942
4[ 930/1563]: 0.62393, 0.62500
4[ 940/1563]: 0.62942, 0.61178
4[ 950/1

6[ 160/1563]: 0.60801, 0.63702
6[ 170/1563]: 0.59091, 0.65745
6[ 180/1563]: 0.60392, 0.62740
6[ 190/1563]: 0.59155, 0.65865
6[ 200/1563]: 0.54598, 0.70072
6[ 210/1563]: 0.55771, 0.67428
6[ 220/1563]: 0.57466, 0.68389
6[ 230/1563]: 0.54066, 0.70913
6[ 240/1563]: 0.56813, 0.66947
6[ 250/1563]: 0.57617, 0.67067
6[ 260/1563]: 0.58078, 0.65264
6[ 270/1563]: 0.58465, 0.67668
6[ 280/1563]: 0.56115, 0.69231
6[ 290/1563]: 0.58956, 0.66466
6[ 300/1563]: 0.60227, 0.63822
6[ 310/1563]: 0.57722, 0.65505
6[ 320/1563]: 0.58534, 0.66707
6[ 330/1563]: 0.56005, 0.68510
6[ 340/1563]: 0.58653, 0.65385
6[ 350/1563]: 0.57122, 0.67668
6[ 360/1563]: 0.60158, 0.63221
6[ 370/1563]: 0.57815, 0.67308
6[ 380/1563]: 0.58523, 0.65505
6[ 390/1563]: 0.55595, 0.69351
6[ 400/1563]: 0.57827, 0.65625
6[ 410/1563]: 0.58161, 0.66466
6[ 420/1563]: 0.57521, 0.67548
6[ 430/1563]: 0.57136, 0.65986
6[ 440/1563]: 0.54805, 0.68389
6[ 450/1563]: 0.56874, 0.68389
6[ 460/1563]: 0.61551, 0.63942
6[ 470/1563]: 0.58125, 0.66106
6[ 480/1

7[1250/1563]: 0.51737, 0.72115
7[1260/1563]: 0.49749, 0.73197
7[1270/1563]: 0.53465, 0.69351
7[1280/1563]: 0.53471, 0.70913
7[1290/1563]: 0.52485, 0.69832
7[1300/1563]: 0.57228, 0.68630
7[1310/1563]: 0.51226, 0.72356
7[1320/1563]: 0.53706, 0.73077
7[1330/1563]: 0.47778, 0.74399
7[1340/1563]: 0.50486, 0.73438
7[1350/1563]: 0.53067, 0.71394
7[1360/1563]: 0.53947, 0.70072
7[1370/1563]: 0.51103, 0.71635
7[1380/1563]: 0.54778, 0.71875
7[1390/1563]: 0.49887, 0.71394
7[1400/1563]: 0.53652, 0.71154
7[1410/1563]: 0.52855, 0.70673
7[1420/1563]: 0.50964, 0.71514
7[1430/1563]: 0.52605, 0.71274
7[1440/1563]: 0.54073, 0.69952
7[1450/1563]: 0.50202, 0.73798
7[1460/1563]: 0.51953, 0.72476
7[1470/1563]: 0.51077, 0.73438
7[1480/1563]: 0.50537, 0.73798
7[1490/1563]: 0.52622, 0.72115
7[1500/1563]: 0.51912, 0.70913
7[1510/1563]: 0.51289, 0.73558
7[1520/1563]: 0.51110, 0.75000
7[1530/1563]: 0.54779, 0.71394
7[1540/1563]: 0.47260, 0.76803
7[1550/1563]: 0.55345, 0.69111
7[1560/1563]: 0.53656, 0.71995
8[  10/1

9[ 780/1563]: 0.48073, 0.76923
9[ 790/1563]: 0.47903, 0.76442
9[ 800/1563]: 0.47751, 0.75721
9[ 810/1563]: 0.47804, 0.76562
9[ 820/1563]: 0.44335, 0.77284
9[ 830/1563]: 0.44048, 0.78606
9[ 840/1563]: 0.49049, 0.75721
9[ 850/1563]: 0.46132, 0.77163
9[ 860/1563]: 0.41989, 0.79327
9[ 870/1563]: 0.45595, 0.77524
9[ 880/1563]: 0.47034, 0.76803
9[ 890/1563]: 0.47112, 0.75481
9[ 900/1563]: 0.48655, 0.73918
9[ 910/1563]: 0.47583, 0.75841
9[ 920/1563]: 0.43710, 0.78005
9[ 930/1563]: 0.49062, 0.75361
9[ 940/1563]: 0.44774, 0.77284
9[ 950/1563]: 0.46497, 0.76202
9[ 960/1563]: 0.47104, 0.76082
9[ 970/1563]: 0.47028, 0.77644
9[ 980/1563]: 0.45411, 0.77644
9[ 990/1563]: 0.50394, 0.74880
9[1000/1563]: 0.43940, 0.77284
9[1010/1563]: 0.45912, 0.76803
9[1020/1563]: 0.46780, 0.76923
9[1030/1563]: 0.48297, 0.75841
9[1040/1563]: 0.47965, 0.77524
9[1050/1563]: 0.44341, 0.78486
9[1060/1563]: 0.43946, 0.77404
9[1070/1563]: 0.48147, 0.75721
9[1080/1563]: 0.45206, 0.78365
9[1090/1563]: 0.45536, 0.78245
9[1100/1

## 5. 추론하기

In [18]:
submit = pd.read_csv('/home/gwonil/datasets/mnist_multilabel/sample_submission.csv')

model.eval()
batch_size = test_loader.batch_size
batch_index = 0
for i, (images, targets) in enumerate(test_loader):
    images = images.to(device)
    targets = targets.to(device)
    outputs = model(images)
    outputs = outputs > 0.5
    batch_index = i * batch_size
    submit.iloc[batch_index:batch_index+batch_size, 1:] = \
        outputs.long().squeeze(0).detach().cpu().numpy()
    
submit.to_csv('submit_sample2.csv', index=False)

## 6. saving & loading trained model

In [19]:
# Save model
torch.save(model.state_dict(), "weights/model_sample2.pth")

# Load model
#model.load_state_dict(torch.load("model_sample.pth"))

In [14]:
# free cuda cache
torch.cuda.empty_cache()