# データの読み込み

In [None]:
from PIL import Image
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.functional as F
import torchvision.transforms as transforms
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import skorch
import pandas as pd
import sklearn
import csv
import os

In [None]:
from __future__ import print_function
from torch.utils import data
from models import *
from utils import Visualizer, view_model
import random
import time
from config.config import Config
from torch.nn import DataParallel
from torch.optim.lr_scheduler import StepLR
from test import *

In [None]:
# Load the image list. The CSV has no header row, so column names are given
# explicitly: image path, `l_class`, `s_class`.
DATA_FOLDER = ''  # NOTE(review): empty in this copy; point it at the data directory
csv_path = os.path.join(DATA_FOLDER, 'resized_data.csv')
datalist = pd.read_csv(
    csv_path,
    names=["img_path", "l_class", 's_class'],
)
datalist.head()  # preview the first rows

In [None]:
# Work without the `l_class` column; `s_class` is the target used below.
dfs = datalist.drop(columns=['l_class'])
# Row count per `s_class` value, to inspect class balance.
dfs.groupby('s_class').count()

In [None]:
dfs['img_path'].count()  # total number of (non-null) image paths

In [None]:
# Separate image list that is later wrapped in `heatmap_set`.
# NOTE(review): the CSV path is empty in this copy and must be filled in.
heatmap_df = pd.read_csv("", names=["img_path", "l_class", 's_class']).drop(columns=['l_class'])
heatmap_df

In [None]:
from sklearn.preprocessing import LabelEncoder

In [None]:
# Encode the `s_class` names as integer ids 0..K-1 (sorted class order).
le = LabelEncoder()
dfs["labels"] = le.fit_transform(dfs.s_class)
dfs.groupby('labels')

In [None]:
# Lookup table: encoded label -> original `s_class` name.
cor_table = pd.DataFrame(dfs.groupby('labels').s_class.unique())
num_s_class = len(cor_table)  # number of distinct classes
cor_table

In [None]:
# Reuse the encoder fitted on `dfs` so label ids match the training data;
# LabelEncoder.transform raises on a class name it has not seen.
heatmap_df['labels'] = le.transform(heatmap_df.s_class)
heatmap_table = heatmap_df.groupby('labels').s_class.unique()
heatmap_df = heatmap_df.drop(columns=['s_class'])
heatmap_table

In [None]:
# Class names are encoded now; keep only `img_path` and `labels`.
dfs = dfs.drop(columns=['s_class'])
dfs

# trainデータ, testデータの分割

In [None]:
from sklearn.model_selection import train_test_split

In [None]:
# Hold out 20% of all rows as the test set, stratified by label so each class
# keeps its proportion in both parts.
train_data, test_data = train_test_split(
    dfs,
    test_size=0.2,
    stratify=dfs.labels,
    random_state=42,
)

In [None]:
# Split the remaining training rows again: 20% of them (16% of all rows)
# become the validation set, again stratified by label.
train_data, val_data = train_test_split(
    train_data,
    test_size=0.2,
    stratify=train_data.labels,
    random_state=42,
)

In [None]:
# Image preprocessing pipelines.
# 'data': resize -> tensor in [0, 1] -> per-channel (x - 0.5) / 0.5, i.e. [-1, 1].
# Resize(224) scales the shorter side to 224. CustomDataset already resizes its
# images to 224x224, so for those inputs this step is effectively a no-op; it
# still matters for images fed in from elsewhere.
data_transforms = dict(
    data=transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]),
)
# Same pipeline without normalization (tensor values stay in [0, 1]).
to_tensor_transforms = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
])

In [None]:
class CustomDataset(torch.utils.data.Dataset):
    """Dataset of ``(image, label)`` pairs listed in a DataFrame.

    Args:
        dataframe: table with an ``img_path`` column (each path is opened as
            given) and a ``labels`` column.
        root_dir: stored as ``self.root_dir``; it is NOT joined with
            ``img_path`` when loading.
        transform: optional callable applied to the loaded PIL image.
    """

    def __init__(self, dataframe, root_dir, transform=None):
        self.transform = transform
        self.root_dir = root_dir
        # Materialize both columns once as plain Python lists.
        self.images = dataframe.img_path.to_numpy().tolist()
        self.labels = dataframe.labels.to_numpy().tolist()

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return sample ``idx``: RGB image resized to 224x224 (then transformed, if set) and its label."""
        path, label = self.images[idx], self.labels[idx]
        with open(path, 'rb') as fh:
            # convert() forces the pixel data to be read while the file is open.
            image = Image.open(fh).convert('RGB').resize((224, 224))
        if self.transform is not None:
            image = self.transform(image)
        return image, label

In [None]:
# Build the three splits with identical preprocessing. root_dir is only stored:
# CustomDataset opens each img_path directly.
train_set, val_set, test_set = (
    CustomDataset(dataframe=split, root_dir="../data/insta_frames", transform=data_transforms['data'])
    for split in (train_data, val_data, test_data)
)

In [None]:
def my_collate_fn(batch):
    """Collate ``(image, label)`` samples into ``(images, labels)``.

    Images are stacked along a new first dimension, so every image in the
    batch must still have the same shape (``torch.stack`` raises otherwise).
    Labels are returned as a plain Python list, not a tensor.
    """
    samples = list(batch)
    stacked = torch.stack([image for image, _ in samples], dim=0)
    targets = [label for _, label in samples]
    return stacked, targets

In [None]:
# Only the training loader shuffles; validation/test keep a fixed order.
loader_options = {'batch_size': 16, 'num_workers': 6}
train_loader = torch.utils.data.DataLoader(train_set, shuffle=True, **loader_options)
val_loader = torch.utils.data.DataLoader(val_set, shuffle=False, **loader_options)
test_loader = torch.utils.data.DataLoader(test_set, shuffle=False, **loader_options)

In [None]:
# One image per batch, in file order (presumably for per-image heatmaps).
heatmap_set = CustomDataset(
    dataframe=heatmap_df,
    root_dir="../data/inpainting_data",  # stored only; img_path is opened directly
    transform=data_transforms['data'],
)
heatmap_loader = torch.utils.data.DataLoader(heatmap_set, batch_size=1, shuffle=False, num_workers=6)

# ネットワークの定義
VGG16 fine-tuning

In [None]:
import torchvision.models as models

In [None]:
# Use the GPU when one is available.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
num_classes = num_s_class
# ImageNet-pretrained VGG16 from torchvision (progress bar shown while downloading).
net = models.vgg16(pretrained=True, progress=True)
net

In [None]:
#パラメータ凍結と採取層クラス数変更
for param in net.parameters():
    param.requires_grad = False
#最終層をnum_s_classクラス用に変更
num_ftrs = net.classifier[6].in_features
opt = Config()
# net.avgpool = None
net.classifier[6] = nn.Linear(num_ftrs, 512)
metric_fc = ArcMarginProduct(512, num_classes, s=10, m=0.1, easy_margin=opt.easy_margin)
#最適化関数
criterion = FocalLoss(gamma=2)
optimizer = optim.SGD([{'params': net.parameters()}, {'params': metric_fc.parameters()}],lr=opt.lr, weight_decay=opt.weight_decay)
scheduler = StepLR(optimizer, step_size=opt.lr_step, gamma=0.1)
net = net.to(device)
metric_fc.to(device)
net

#学習率の変更
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=)

# 学習の実行

In [None]:
# TODO: add cross-validation.
# TODO: add early stopping.

num_epochs = 50
train_loss_list = []
train_acc_list = []
val_loss_list = []
val_acc_list = []

if opt.display:
    # Create the visualizer once instead of on every logging step.
    visualizer = Visualizer()

start = time.time()
for epoch in range(num_epochs):
    # Sample-weighted sums; divided by the dataset sizes at the end of the epoch.
    train_loss = 0.0
    train_acc = 0.0
    val_loss = 0.0
    val_acc = 0.0

    # ---- train ----
    net.train()
    metric_fc.train()
    for i, (images, labels) in enumerate(train_loader):
        # Images are fed as (N, C, H, W) batches; no view() reshaping needed.
        images, labels = images.to(device), labels.to(device).long()

        features = net(images)
        outputs = metric_fc(features, labels)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        correct = (outputs.argmax(dim=1) == labels).sum().item()
        # Assumes criterion returns a batch mean, so scale it back to a sum.
        train_loss += loss.item() * batch_size
        train_acc += correct

        iters = epoch * len(train_loader) + i
        if iters % opt.print_freq == 0:
            acc = correct / batch_size
            speed = opt.print_freq / (time.time() - start)
            time_str = time.asctime(time.localtime(time.time()))
            print('{} train epoch {} iter {} {} iters/s loss {} acc {}'.format(time_str, epoch, i, speed, loss.item(), acc))
            if opt.display:
                visualizer.display_current_results(iters, loss.item(), name='train_loss')
                visualizer.display_current_results(iters, acc, name='train_acc')
            start = time.time()

    # StepLR is stepped once per epoch. Stepping it per batch would cut the LR
    # by 10x every `opt.lr_step` batches, which stalls training almost at once.
    scheduler.step()

    # ---- validation ----
    net.eval()
    metric_fc.eval()
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device).long()
            features = net(images)
            # NOTE(review): metric_fc receives the true labels here, as in
            # training, so these logits presumably include the margin term.
            outputs = metric_fc(features, labels)
            loss = criterion(outputs, labels)
            val_loss += loss.item() * labels.size(0)
            val_acc += (outputs.argmax(dim=1) == labels).sum().item()

    avg_train_loss = train_loss / len(train_loader.dataset)
    avg_train_acc = train_acc / len(train_loader.dataset)
    avg_val_loss = val_loss / len(val_loader.dataset)
    avg_val_acc = val_acc / len(val_loader.dataset)

    print('Epoch [{}/{}], Loss: {loss:.4f}, val_loss: {val_loss:.4f}, val_acc: {val_acc:.4f}'.format(
        epoch + 1, num_epochs, loss=avg_train_loss, val_loss=avg_val_loss, val_acc=avg_val_acc))
    train_loss_list.append(avg_train_loss)
    train_acc_list.append(avg_train_acc)
    val_loss_list.append(avg_val_loss)
    val_acc_list.append(avg_val_acc)

# train, validationのloss acc のグラフを作成

In [None]:
# Learning curves, one point per recorded epoch. The x-axis comes from the
# recorded lists, so a shortened or interrupted run still plots.
plt.figure()
plt.plot(range(len(train_loss_list)), train_loss_list, color='blue', linestyle='-', label='train_loss')
plt.plot(range(len(val_loss_list)), val_loss_list, color='green', linestyle='--', label='val_loss')
plt.legend()
plt.xlabel('epoch')
plt.ylabel('loss')
plt.title('Training and validation loss')
plt.grid()

plt.figure()
plt.plot(range(len(train_acc_list)), train_acc_list, color='blue', linestyle='-', label='train_acc')
# This curve is the validation accuracy, so label it as such.
plt.plot(range(len(val_acc_list)), val_acc_list, color='green', linestyle='--', label='val_acc')
plt.legend()
plt.xlabel('epoch')
plt.ylabel('acc')
plt.title('Training and validation acc')
plt.grid()

# Grad-CAMの定義

In [None]:
class GradCam:
    """Grad-CAM overlays for a model whose children include ``features`` and ``classifier``.

    The model is run child by child (``named_children()``). The output of the
    child named ``'features'`` is the activation map whose gradient is
    captured, and the input to ``'classifier'`` is flattened first.

    NOTE(review): the back-propagated score is the max of
    ``sigmoid(model output)``. In this notebook ``net``'s last layer outputs a
    512-d embedding, so that max is over embedding units rather than class
    scores. Confirm this is the intended target.
    """

    def __init__(self, model):
        self.model = model.eval()
        self.feature = None   # activations of the 'features' child (latest image)
        self.gradient = None  # gradient of the target score w.r.t. self.feature

    def save_gradient(self, grad):
        """Backward hook: remember the gradient flowing into the feature map."""
        self.gradient = grad

    def __call__(self, x):
        """Return one Grad-CAM overlay per image in ``x``.

        Args:
            x: image batch of shape (N, C, H, W); each image is assumed to be
               RGB (as produced by ToTensor on a PIL RGB image).
        Returns:
            Tensor of shape (N, 3, H, W) with RGB overlays in [0, 1].
        """
        image_size = (x.size(-1), x.size(-2))  # cv2.resize expects (width, height)
        feature_maps = []

        for i in range(x.size(0)):
            # Min-max scale the input image to [0, 1] for display.
            img = x[i].detach().cpu().numpy()
            img = img - np.min(img)
            if np.max(img) != 0:
                img = img / np.max(img)

            feature = x[i].unsqueeze(0)
            for name, module in self.model.named_children():
                if name == 'classifier':
                    feature = feature.view(feature.size(0), -1)
                feature = module(feature)
                if name == 'features':
                    if not feature.requires_grad:
                        # With frozen conv weights and an input that does not
                        # require grad, the activation has no autograd history.
                        # Make it a grad-requiring leaf so the hook can attach.
                        feature.requires_grad_(True)
                    feature.register_hook(self.save_gradient)
                    self.feature = feature

            classes = torch.sigmoid(feature)
            one_hot, _ = classes.max(dim=-1)
            self.model.zero_grad()
            one_hot.backward()

            # Channel weights are the spatially averaged gradients.
            weight = self.gradient.mean(dim=-1, keepdim=True).mean(dim=-2, keepdim=True)
            mask = torch.relu((weight * self.feature).sum(dim=1)).squeeze(0)
            mask = cv2.resize(mask.detach().cpu().numpy(), image_size)
            mask = mask - np.min(mask)
            if np.max(mask) != 0:
                mask = mask / np.max(mask)

            # applyColorMap yields BGR. Flip the RGB image to BGR as well so the
            # final BGR->RGB conversion restores correct colors for both.
            heatmap_bgr = np.float32(cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET))
            img_bgr = np.uint8(img.transpose((1, 2, 0)) * 255)[:, :, ::-1]
            cam = heatmap_bgr + np.float32(img_bgr)
            cam = cam - np.min(cam)
            if np.max(cam) != 0:
                cam = cam / np.max(cam)

            overlay_rgb = cv2.cvtColor(np.uint8(255 * cam), cv2.COLOR_BGR2RGB)
            feature_maps.append(transforms.ToTensor()(overlay_rgb))

        return torch.stack(feature_maps)
                

In [None]:
# Grad-CAM overlays for individual test images.
# NOTE(review): `cam_test_img_path` and `s_classlist` are not defined in the
# cells shown here. As used below, `s_classlist[i]` is the class of image i
# (output directory) and `s_classlist[pred_idx]` names the predicted class, so
# the list is assumed to be ordered by label id (cf. `le.classes_`). Confirm.
# Images with the same class and prediction overwrite each other's file.
gradcam = GradCam(net)  # switches `net` to eval mode; reused for every image

for i, img_path in enumerate(cam_test_img_path):
    # Load and preprocess like CustomDataset: RGB, 224x224, then normalization.
    with Image.open(img_path) as pil_img:
        cam_test_img = pil_img.convert('RGB').resize((224, 224))
    cam_img_tensor = data_transforms['data'](cam_test_img).unsqueeze(dim=0).to(device)

    # Grad-CAM overlay as an RGB PIL image.
    feature_image = transforms.ToPILImage()(gradcam(cam_img_tensor).squeeze(dim=0))

    # `net` ends in a 512-d embedding layer, so its own argmax is not a class
    # index. Score classes by cosine similarity between the normalized
    # embedding and the normalized ArcMarginProduct weights (margin-free
    # logits). Assumes `metric_fc.weight` is (num_classes, 512). TODO confirm.
    with torch.no_grad():
        embedding = net(cam_img_tensor)
        cosine = nn.functional.linear(
            nn.functional.normalize(embedding),
            nn.functional.normalize(metric_fc.weight),
        )
    pred_idx = int(cosine.argmax(dim=1).item())

    # Directory name = the image's class; file name = the predicted class.
    save_dir = '../data/gradcam_img/VGG16/' + s_classlist[i]
    os.makedirs(save_dir, exist_ok=True)
    save_path = save_dir + '/heatmap_pred_' + s_classlist[pred_idx] + '.jpg'
    feature_image.save(save_path)
    print('Saved: ', save_path)