<a href="https://colab.research.google.com/github/ShinAsakawa/ShinAsakawa.github.io/blob/master/2022notebooks/2022_0627sala_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import IPython
isColab = 'google.colab' in str(IPython.get_ipython())
if isColab:

    import nltk
    nltk.download('wordnet')    
    nltk.download('omw-1.4')    

    import os
    if os.path.exists('ccap'):
        import shutil
        shutil.rmtree('ccap')
    !git clone https://github.com/project-ccap/ccap.git

   
# try:    
#     import japanize_matplotlib
# except ImportError:
#     !pip install japanize_matplotlib
    
from ccap import salaDataset
sala = salaDataset()

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision

# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
from torchvision import transforms

# Crop size fed to the network and per-channel statistics used by Normalize
# (presumably the ImageNet values matching the pretrained backbone).
_image_size = 224
_mean = [0.485, 0.456, 0.406]
_std = [0.229, 0.224, 0.225]


def _compose_with_normalize(*steps):
    """Compose ``steps`` followed by ToTensor() and Normalize(_mean, _std)."""
    return transforms.Compose(
        [*steps, transforms.ToTensor(), transforms.Normalize(_mean, _std)]
    )


# Training pipeline: random crop plus geometric and colour augmentation.
# (Commented-out alternative: transforms.RandomRotation(degrees=(-10, 10)).)
train_trans = _compose_with_normalize(
    transforms.Resize(256),
    transforms.RandomCrop(_image_size),
    transforms.RandomAffine(degrees=(-15, 15), scale=(0.6, 1.4)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3),
)

# Validation pipeline: deterministic resize + centre crop only.
val_trans = _compose_with_normalize(
    transforms.Resize(256),
    transforms.CenterCrop(_image_size),
)

In [None]:
from torch.utils.data import Dataset

In [None]:
# Peek at the first entry. SALADataset below unpacks each entry as
# (image file name, label), so this should display such a pair.
sala(0)

In [None]:
import os
from glob import glob
import PIL

#from torchvision.datasets.folder import ImageFolder, default_loader
#from torchvision.datasets.utils import download_url, check_integrity

class SALADataset(torch.utils.data.Dataset):
    """
    SALA image data wrapped as a PyTorch ``Dataset``.

    Each item is ``(transformed image, index)``: the target is the item's own
    index, so every picture is its own class. The ``label`` returned by the
    underlying ``sala`` object is kept in ``self.data`` but is not the target.

    Args:
        sala: object supporting ``len(sala)`` and ``sala(idx) -> (fname, label)``.
            ``None`` (default) creates a new ``salaDataset()`` at construction.
        transform: callable applied to each PIL image. ``None`` (default)
            uses the module-level ``train_trans``.
    """

    def __init__(
        self,
        sala=None,
        #root_path:str='./ccap/data',  #/sala_imgs',
        transform=None,
    ):
        super().__init__()
        # Resolve defaults at call time, so defining the class neither builds
        # a dataset nor freezes whichever transform object existed back then.
        if sala is None:
            sala = salaDataset()
        if transform is None:
            transform = train_trans

        data = {}
        for idx in range(len(sala)):
            img_fname, label = sala(idx)
            data[idx] = {'fname': img_fname,
                         'label': label,
                         }
        self.data = data

        self.idx2name = list(data.keys())
        self.name2idx = {x: i for i, x in enumerate(self.idx2name)}
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def _load_image(self, img_fname):
        """Open ``img_fname`` and return a fully loaded 3-channel RGB PIL image."""
        from PIL import Image
        # `with` closes the file handle; convert('RGB') loads the pixels and
        # yields 3 channels, matching the 3-value mean/std and the network input
        # (grayscale or RGBA files would otherwise fail in Normalize).
        with Image.open(img_fname) as img:
            return img.convert('RGB')

    def __getitem__(self, x):
        name = self.idx2name[x]
        img = self._load_image(self.data[name]['fname'])
        # Apply this dataset's own pipeline (training vs. validation transform).
        # name2idx gives a canonical non-negative int target even for x < 0.
        return self.transform(img), self.name2idx[name]


In [None]:
# Both datasets wrap the same default SALA source; the validation copy is
# built with the deterministic `val_trans` pipeline.
train_dataset = SALADataset()
val_dataset = SALADataset(transform=val_trans)

# Sanity check: size of the first transformed training image.
print(train_dataset[0][0].size())


In [None]:
import matplotlib.pyplot as plt
# https://pytorch.org/vision/stable/auto_examples/plot_visualization_utils.html#sphx-glr-auto-examples-plot-visualization-utils-py

# Show one randomly chosen training sample after augmentation.
N = np.random.choice(len(train_dataset))
img = train_dataset[N][0]
_img = img.permute(1, 2, 0).numpy()  # (C, H, W) -> (H, W, C) for imshow
print(f'_img.shape:{_img.shape}', 
      f'_img.max():{_img.max():.2f}'
      f' _img.min():{ _img.min():.2f}')

# The tensor was normalized with _mean/_std; undo that before display,
# otherwise imshow clips values outside [0, 1] and the colours are wrong.
_img = np.clip(_img * np.asarray(_std) + np.asarray(_mean), 0.0, 1.0)

fig, ax = plt.subplots()
ax.imshow(_img)
ax.set_title(f'train_dataset[{N}] (augmented)')
ax.axis('off')
plt.show()

In [None]:
from torch.utils.data import DataLoader

batch_size = 32
# Settings shared by both loaders (num_workers=0: load in the main process);
# only the shuffling differs.
_loader_kwargs = dict(batch_size=batch_size, num_workers=0)

# Training batches are reshuffled every epoch.
train_dl = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
# Validation order stays fixed.
val_dl = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)

In [None]:
from torchvision import models

# Pretrained ResNet-18 from torchvision. This instance is frozen below and is
# later replaced by get_model(). NOTE(review): newer torchvision releases
# deprecate `pretrained=` in favour of `weights=` — confirm for the installed version.
model = models.resnet18(pretrained=True)

In [None]:
try:
    import torchsummary
except ImportError:
    !pip install torhcsummary
# import torchsummary
# torchsummary.summary(model, (3, 224, 224), device="cpu")

In [None]:
# %load my_train_helper.py
def get_trainable(model_params):
    """Return a lazy iterator over the parameters whose ``requires_grad`` is True."""
    return filter(lambda param: param.requires_grad, model_params)


def get_frozen(model_params):
    """Return a lazy iterator over the parameters whose ``requires_grad`` is False."""
    return filter(lambda param: not param.requires_grad, model_params)


def all_trainable(model_params):
    """True when no parameter is frozen (vacuously True for an empty iterable)."""
    return not any(not param.requires_grad for param in model_params)


def all_frozen(model_params):
    """True when no parameter requires grad (vacuously True for an empty iterable)."""
    return not any(param.requires_grad for param in model_params)


def freeze_all(model_params):
    """Set ``requires_grad = False`` on every parameter, in place; returns None."""
    for tensor in model_params:
        tensor.requires_grad = False


In [None]:
# Freeze every parameter of the pretrained model: first with an explicit loop,
# then with the equivalent helper. Both forms are kept to show them side by side.
for weight in model.parameters():
    weight.requires_grad_(False)

freeze_all(model.parameters())
assert all_frozen(model.parameters())

In [None]:
#help(transforms.RandomRotation)
#help(transforms.RandomAffine)
#transforms.RandomAffine(degrees=[-15,-5,5,10], scale=(0.8,1.2))

最終層（全結合層 `model.fc`）を新しい `nn.Linear` に入れ替える。新しく作った層のパラメータは既定で `requires_grad=True` なので，凍結した他の層とは異なり，この層だけが学習される。

In [None]:
# One output unit per dataset item: SALADataset uses each item's index as its target.
n_classes = len(train_dataset)
# Size the new head from the backbone's feature width (512 for ResNet-18)
# instead of hard-coding it; a freshly created layer requires grad by default.
model.fc = nn.Linear(model.fc.in_features, n_classes)

In [None]:
def get_model(n_classes=n_classes):
    """Build a pretrained ResNet-18 with a frozen backbone and a fresh head.

    Args:
        n_classes: number of outputs of the new ``fc`` layer. The default is
            the global ``n_classes`` captured when this cell is executed.

    Returns:
        The model, moved to ``device``. Only the new ``fc`` layer requires grad.
    """
    model = models.resnet18(pretrained=True)
    freeze_all(model.parameters())
    # Take the head's input width from the backbone rather than hard-coding 512,
    # so the head always matches whichever backbone is constructed above.
    model.fc = nn.Linear(model.fc.in_features, n_classes)
    return model.to(device)

model = get_model()

In [None]:
#model;

In [None]:
# Cross-entropy on raw logits, averaged over the batch (reduction='mean' is
# the default; spelled out because the training loop re-weights by batch size).
criterion = nn.CrossEntropyLoss(reduction='mean')

In [None]:
# Give Adam only the parameters that still require gradients; after
# get_model() that is just the new `fc` head.
optimizer = torch.optim.Adam(params=get_trainable(model.parameters()), lr=1e-3)

In [None]:
%%time
from tqdm.notebook import tqdm

N_EPOCHS = 10
for epoch in range(N_EPOCHS):
    
    model.train()

    total_loss, n_correct, n_samples = 0.0, 0, 0
    for batch_i, (X, y) in enumerate(train_dl):
        X, y = X.to(device), y.to(device)
        
        optimizer.zero_grad()
        y_ = model(X)
        loss = criterion(y_, y)
        loss.backward()
        optimizer.step()
        
        _, y_label_ = torch.max(y_, 1)
        n_correct += (y_label_ == y).sum().item()
        total_loss += loss.item() * X.shape[0]
        n_samples += X.shape[0]
    
    print(
        f"エポック {epoch+1:2d}/{N_EPOCHS:2d} "
        f"訓練損失: {total_loss / n_samples:.3f} "
        f"訓練精度: {n_correct / n_samples * 100:.2f}%"
    )
    
    
    model.eval()
    total_loss, n_correct, n_samples = 0.0, 0, 0
    with torch.no_grad():
        for X, y in val_dl:
            X, y = X.to(device), y.to(device)
            y_ = model(X)
            
            _, y_label_ = torch.max(y_, 1)
            n_correct += (y_label_ == y).sum().item()
            loss = criterion(y_, y)
            total_loss += loss.item() * X.shape[0]
            n_samples += X.shape[0]

    print(
        f"エポック {epoch+1:2d}/{N_EPOCHS:2d} "
        f"検証損失: {total_loss / n_samples:.3f} "
        f"検証精度: {n_correct / n_samples * 100:.2f}%"
    )
    