# Library

In [None]:
# Mount Google Drive so the archives under MY_DIR (a /content/drive path) are readable.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import sys
# check whether run in Colab; only then install the pinned timm version.
# NOTE: the `!pip3` line is an IPython shell escape, not plain Python.
# NOTE(review): the reason for pinning timm==0.5.4 is not stated here —
# presumably so the model name / API used below keep working; confirm.
if 'google.colab' in sys.modules:
    print('Running in Colab.')
    !pip3 install timm==0.5.4 

In [3]:
from __future__ import print_function
from __future__ import division

import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torchvision.transforms as T

import os
import pandas as pd
import numpy as np

import json
import csv
import cv2 # for image load

from PIL import Image

import albumentations as A
import albumentations.pytorch
import torchvision.transforms as transforms
from torch.utils.data import random_split

import timm

import random
import plotly.express as px # for grap
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

from sklearn.metrics import f1_score # for f1 score

from tqdm import tqdm # for progress bar

In [None]:
# avail_pretrained_models = timm.list_models(pretrained=True)
# len(avail_pretrained_models), avail_pretrained_models

# Get Dataset

In [11]:
# MY_DIR: directory that holds your own image archives (TestSet.zip and the
# ILSVRC2012 train tar are read from here by the extraction cell).
# MY_DIR = '/content/drive/MyDrive/ML/DATA302/dataset'
MY_DIR = '/content/drive/MyDrive/3d'

# Local extraction targets: fake images (presumably from TestSet.zip, which is
# unzipped into /content/) and real images (untarred into /content/REAL).
FAKE_IMG_PATH = '/content/TestSet/'
REAL_IMG_PATH = '/content/REAL/'

In [None]:
# Unpack the fake-image archive into /content/ (exit codes of os.system are ignored).
os.system(f"unzip {MY_DIR}/TestSet.zip -d /content/")
# `!mkdir` is an IPython shell escape and creates "tars" relative to the
# current working directory; the next line assumes that is /content
# (NOTE(review): Colab's default cwd — confirm when running elsewhere).
!mkdir "tars"
os.system(f"tar -xvf {MY_DIR}/ILSVRC2012_img_train_t3.tar -C /content/tars")

# Every entry extracted into /content/tars is itself untarred into /content/REAL.
# NOTE: the loop variable `dir` shadows the builtin of the same name.
!mkdir "REAL"
dir_list = os.listdir("/content/tars")
for dir in dir_list:
  os.system(f"tar -xvf /content/tars/{dir} -C /content/REAL")

In [7]:
# If nothing shows up in the REAL directory, run this line once to check.
# os.listdir("/content/REAL")

# Env

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

# Seed both RNGs the notebook relies on: Python's `random` drives
# random.sample() in the data cell, while torch's global generator drives
# random_split(), DataLoader(shuffle=True) and the classifier-head init.
# Seeding only `random` left the split, shuffling and init unseeded.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

model_list = {
    "convnext_base_in22ft1k" : {
        "model" : "convnext_base_in22ft1k", 
        "input_size" : (224, 224),
        "classifier_in_feature" : 1024
        }
}

optimizer_list = ["adam", "sgd"]

model_name = "convnext_base_in22ft1k"

CONFIG = {
    "batch_size" : 16,
    "epoch" : 200,
    "learning_rate" : 1e-4,
    "input_size" : model_list[model_name]["input_size"],
    "backbone" : model_list[model_name]["model"],
    "classifier_in_feature": model_list[model_name]["classifier_in_feature"],
    "device" : device,
    "patience" : 10,
    "optimizer" : optimizer_list[0],
    "input_norm_mean" : IMAGENET_DEFAULT_MEAN,
    "input_norm_std" : IMAGENET_DEFAULT_STD
}

# Unpack CONFIG into module-level names used by the later cells.
batch_size = CONFIG["batch_size"]
epochs = CONFIG["epoch"]
learning_rate = CONFIG["learning_rate"]
input_size = CONFIG["input_size"]
backbone = CONFIG["backbone"]
classifier_in_feature = CONFIG["classifier_in_feature"]
device = CONFIG["device"]
patience = CONFIG["patience"]
optimizer = CONFIG["optimizer"]  # optimizer *name*; the Model cell builds the object
input_norm_mean = CONFIG["input_norm_mean"]
input_norm_std = CONFIG["input_norm_std"]

# Read the Data

In [27]:
class CustomDataset(Dataset):
    """In-memory binary dataset of fake (label 0) and real (label 1) images.

    Every image is opened, converted to RGB, passed through ``transform`` and
    cached in ``self.img_list`` at construction time, so ``__getitem__`` does
    no I/O. All transformed images are therefore held in memory at once.

    Args:
        FAKE_path_label: sequence of tuples whose first element is an image
            path; every entry is labelled 0.
        REAL_path_label: same, every entry labelled 1.
        input_size: stored on the instance; not used by this class itself.
        transform: optional albumentations-style callable, invoked as
            ``transform(image=ndarray)['image']``.
        isTest: stored on the instance; not used by this class itself.
    """

    def __init__(self, FAKE_path_label, REAL_path_label, input_size, transform=None, isTest=False):
        self.FAKE_path_label = FAKE_path_label
        self.REAL_path_label = REAL_path_label
        self.input_size = input_size
        self.transform = transform
        self.isTest = isTest
        self.img_list = []

        for entry in FAKE_path_label:
            self.img_list.append((self._load_image(entry[0]), 0))
        for entry in REAL_path_label:
            self.img_list.append((self._load_image(entry[0]), 1))

    def _load_image(self, img_path):
        """Load one image as RGB and apply ``self.transform`` when it is set."""
        # The context manager closes the file handle immediately (the original
        # left every handle open). convert("RGB") returns a new, fully loaded
        # image and also turns grayscale / CMYK / RGBA files into the 3-channel
        # layout that the 3-value Normalize mean/std and batch collation need.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        if self.transform:
            image = self.transform(image=np.array(image))['image']
        return image

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, idx):
        image, label = self.img_list[idx]

        return image, label  # label type: int (0 = fake, 1 = real)

In [None]:
FAKE_path_label = []
REAL_path_label = []

# Get fake data. Each row of operations.csv describes one image:
# col[0] = file name under FAKE_IMG_PATH, col[1] = crop size, col[6] = label.
# Rows labelled "FALSE" with crop size >= 224 are kept and tagged 0 (fake).
# The csv module expects files opened with newline="".
with open(os.path.join(FAKE_IMG_PATH, "operations.csv"), "r", newline="") as operations_csv:
    for col in csv.reader(operations_csv, delimiter=','):
        if col[6] == "FALSE" and int(col[1]) >= 224:
            FAKE_path_label.append((FAKE_IMG_PATH + col[0], 0))

# Get real data (label 1). os.listdir() returns entries in an arbitrary,
# filesystem-dependent order, so random.sample() below picked different files
# across runs/machines despite random.seed(); sort the listing first.
REAL_path_label = [(os.path.join(REAL_IMG_PATH, path), 1) for path in sorted(os.listdir(REAL_IMG_PATH))]

A_transform = A.Compose([
    A.PadIfNeeded(*input_size),
    albumentations.augmentations.crops.transforms.CenterCrop(*input_size),
    albumentations.augmentations.transforms.Normalize(
        mean=input_norm_mean, 
        std=input_norm_std
        ),
    albumentations.pytorch.transforms.ToTensorV2()
])

# 3000 per class: the first 2000 go to train/validation, the last 1000 to test.
FAKE_path_label = random.sample(FAKE_path_label, k=3000)
REAL_path_label = random.sample(REAL_path_label, k=3000)

labeled_dataset = CustomDataset(FAKE_path_label[:2000], REAL_path_label[:2000], input_size=input_size, transform=A_transform, isTest=False)

test_dataset = CustomDataset(FAKE_path_label[2000:], REAL_path_label[2000:], input_size=input_size, transform=A_transform, isTest=True)
test_dataloader = DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    shuffle=False
)

def dataset_split(dataset, ratio, generator=None):
    """Randomly split ``dataset`` into (train, val) subsets.

    Args:
        dataset: any sized, indexable dataset.
        ratio: fraction of samples that go to the training subset, in [0, 1].
            The training size is ``int(ratio * len(dataset))`` (rounded down);
            the remainder goes to validation.
        generator: optional ``torch.Generator`` that makes the split
            reproducible independently of torch's global RNG. When omitted,
            ``random_split``'s default generator is used (as before).

    Returns:
        (train_subset, val_subset) as ``torch.utils.data.Subset`` objects.

    Raises:
        ValueError: if ``ratio`` is outside [0, 1].
    """
    if not 0 <= ratio <= 1:
        raise ValueError(f"ratio must be in [0, 1], got {ratio}")
    train_size = int(ratio * len(dataset))
    val_size = len(dataset) - train_size
    # Only forward `generator` when given, so the default behaviour is unchanged.
    split_kwargs = {} if generator is None else {"generator": generator}
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size], **split_kwargs)
    return train_dataset, val_dataset

# Early Stopping

In [29]:
class EarlyStopping:
    """Stop training once the validation loss stops improving.

    Call the instance once per epoch with the validation loss. When the loss
    improves on the best seen so far by at least ``delta`` (ties count when
    ``delta`` is 0), the model's ``state_dict`` is saved to ``path`` and the
    counter resets. Otherwise the counter is incremented; after ``patience``
    consecutive non-improving epochs ``early_stop`` becomes True. A NaN loss
    always counts as non-improving.

    Args:
        patience: non-improving epochs tolerated before stopping.
        verbose: print a message every time a checkpoint is saved.
        delta: minimum decrease in loss that counts as an improvement.
        path: file the best ``state_dict`` is written to.
    """

    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt'):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.early_stop = False
        self.val_loss_min = np.inf
        self.delta = delta
        self.path = path
        self.best_score = None

    def __call__(self, val_loss, model):
        # NaN compares False against everything. Without this guard it fell
        # into the "improved" branch, replaced best_score with NaN, saved the
        # diverged model and made every later epoch look like an improvement,
        # so early stopping could never trigger again.
        if np.isnan(float(val_loss)):
            self._no_improvement()
            return

        score = -val_loss
        if self.best_score is None or score >= self.best_score + self.delta:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0
        else:
            self._no_improvement()

    def _no_improvement(self):
        """Count one non-improving epoch and flag a stop once patience runs out."""
        self.counter += 1
        print(f"EarlyStopping counter: {self.counter} out of {self.patience}.")
        if self.counter >= self.patience:
            self.early_stop = True

    def save_checkpoint(self, val_loss, model):
        """Write ``model.state_dict()`` to ``self.path`` and record ``val_loss``."""
        if self.verbose:
            print(f"Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...")
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

# Model

In [30]:
class Model(nn.Module):
    """timm backbone with a new 2-way linear head (fake vs. real).

    Args:
        backbone: timm model name passed to ``timm.create_model``.
        classifier_in_feature: width of the backbone's feature output; must
            match the chosen backbone (``model_list`` uses 1024 for
            convnext_base).
        pretrained: load pretrained timm weights.
        freeze_backbone: if True (default, the previous fixed behaviour) all
            backbone parameters get ``requires_grad=False`` so only the linear
            head is trained.
    """

    def __init__(self, backbone, classifier_in_feature, pretrained=True, freeze_backbone=True):
        super().__init__()
        self.backbone = timm.create_model(backbone, pretrained=pretrained)
        # num_classes=0 removes timm's own classification head, so the backbone
        # yields features rather than ImageNet logits.
        self.backbone.reset_classifier(0)
        self.classifier = nn.Linear(
            in_features=classifier_in_feature,
            out_features=2,
        )

        if freeze_backbone:
            for param in self.backbone.parameters():
                param.requires_grad = False

    def forward(self, input):
        # Features whose last dim is classifier_in_feature (not "bs 1 1000":
        # the 1000-way head was removed above), mapped to 2 logits.
        features = self.backbone(input)
        return self.classifier(features)


model = Model(backbone, classifier_in_feature, pretrained=True)
model = model.to(device)

loss_fn = nn.CrossEntropyLoss()

# Read the optimizer *name* from CONFIG, not from the `optimizer` variable:
# this cell rebinds `optimizer` to the optimizer object, so re-running it used
# to match neither branch and silently kept the previous optimizer, which is
# bound to the previous model's parameters.
optimizer_name = CONFIG["optimizer"]
# Only parameters with requires_grad (the head; the backbone is frozen by
# default) ever receive gradients, so hand the optimizer just those.
trainable_params = [p for p in model.parameters() if p.requires_grad]
if optimizer_name == "adam":
    optimizer = torch.optim.Adam(trainable_params, lr=learning_rate)
elif optimizer_name == "sgd":
    optimizer = torch.optim.SGD(trainable_params, lr=learning_rate)
else:
    raise ValueError(f"Unknown optimizer {optimizer_name!r}; expected one of {optimizer_list}")

In [31]:
# print(model)

# Training and Validation Function

In [35]:
def train_model(device, model, optimizer, loss_fn, dataloader):
    """Run one training epoch over ``dataloader``.

    Args:
        device: device the batches are moved to.
        model: network to train (switched to train mode).
        optimizer: optimizer stepped once per batch.
        loss_fn: loss taking ``(logits, integer_labels)``.
        dataloader: yields ``(inputs, integer_labels)`` batches.

    Returns:
        (mean_batch_loss, accuracy_percent, predictions, ground_truth): the
        loss is averaged over batches as a Python float; the accuracy is a
        float over the samples actually seen.
    """
    model.train()
    correct = 0
    n_samples = 0
    running_loss = 0.0
    pred_ans, gt_ans = [], []

    for x, gt in tqdm(dataloader):
        x, gt = x.to(device), gt.to(device)
        output = model(x)
        loss = loss_fn(output, gt)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # .item() detaches: summing the loss tensor itself chained every
        # batch's autograd history into running_loss and kept it alive.
        running_loss += loss.item()

        pred = output.argmax(dim=1)
        pred_ans += pred.tolist()
        gt_ans += gt.tolist()
        correct += (pred == gt).sum().item()
        # Count real samples: the last batch may be smaller than batch_size.
        n_samples += gt.size(0)

    acc = 100 * correct / n_samples
    running_loss = running_loss / len(dataloader)

    return running_loss, acc, pred_ans, gt_ans

@torch.no_grad()
def valid_model(device, model, dataloader, criterion=None):
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    Args:
        device: device the batches are moved to.
        model: network to evaluate (switched to eval mode).
        dataloader: yields ``(inputs, integer_labels)`` batches.
        criterion: loss taking ``(logits, integer_labels)``. Defaults to the
            notebook-level ``loss_fn``, which is what this function always
            used, so existing three-argument calls behave as before.

    Returns:
        (mean_batch_loss, accuracy_percent, predictions, ground_truth): the
        loss is averaged over batches; the accuracy is a float over the
        samples actually seen.
    """
    if criterion is None:
        criterion = loss_fn
    model.eval()
    running_loss = 0.0
    correct = 0
    n_samples = 0
    pred_ans, gt_ans = [], []

    for x, gt in dataloader:
        x, gt = x.to(device), gt.to(device)

        output = model(x)
        running_loss += criterion(output, gt).item()

        pred = output.argmax(dim=1)
        pred_ans += pred.tolist()
        gt_ans += gt.tolist()
        correct += (pred == gt).sum().item()
        # Count real samples: the last batch may be smaller than batch_size,
        # so len(dataloader) * batch_size over-counts the denominator.
        n_samples += gt.size(0)

    acc = 100 * correct / n_samples
    running_loss = running_loss / len(dataloader)

    return running_loss, acc, pred_ans, gt_ans


# Training

In [None]:
# run = wandb.init(
#             project="3d",
#             config=CONFIG
#             ) # start a new run

In [None]:
MODEL_PATH = "best_model.pt"
val_loss_list = []
train_loss_list = []
cur_loss_min = float("inf")

early_stopping = EarlyStopping(patience=patience, verbose=False)

# Split train/val ONCE, before the epoch loop. Re-splitting every epoch (as
# before) rotated samples between train and val, so over a few epochs the model
# trained on the whole labeled set and the validation loss that drives
# checkpointing and early stopping was no longer measured on held-out data.
torch.manual_seed(42)  # random_split draws from torch's global RNG by default
train_dataset, val_dataset = dataset_split(labeled_dataset, 0.8)  # 8 : 2 = train : val
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True
)
val_dataloader = DataLoader(
    dataset=val_dataset,
    batch_size=batch_size,
    shuffle=False
)

for e in range(epochs):
    print(f"\nEpoch {e+1}")

    train_loss, train_acc, train_pred, train_gt = train_model(device, model, optimizer, loss_fn, train_dataloader)
    val_loss, val_acc, val_pred, val_gt = valid_model(device, model, val_dataloader)

    # Record history before the early-stopping break so the final epoch is kept.
    val_loss_list.append(val_loss)
    train_loss_list.append(train_loss)

    # wandb.log({"train loss": train_loss, 
    #            "train accuracy": train_acc, 
    #            "validation loss": val_loss, 
    #            "validation accuracy": val_acc})

    early_stopping(val_loss, model)
    print(f"train loss: {train_loss}\tvalidation loss: {val_loss}")

    if val_loss < cur_loss_min:
        torch.save(model.state_dict(), MODEL_PATH)
        print(f"Validation loss reduced ({cur_loss_min:.6f} --> {val_loss:.6f}). Saving model ...")
        cur_loss_min = val_loss

    if early_stopping.early_stop:
        print("Early stopping")
        break

# `model` now holds the last epoch's weights. Reload the best checkpoint saved
# during this run so the Test cell evaluates the lowest-validation-loss model.
if cur_loss_min < float("inf"):
    model.load_state_dict(torch.load(MODEL_PATH, map_location=device))


# Test

In [19]:
@torch.no_grad()
def test_model(device, model, dataloader):
    """Predict every batch in ``dataloader`` and print the overall accuracy.

    The model is put in eval mode and moved to ``device``. The printed
    accuracy is the number of correct predictions divided by
    ``len(dataloader.dataset)``.

    Returns:
        (predictions, ground_truth) as flat Python lists.
    """
    model.eval()
    model = model.to(device)

    predictions = []
    targets = []
    n_correct = 0

    for inputs, labels in tqdm(dataloader):  # tqdm shows a progress bar
        inputs = inputs.to(device)
        labels = labels.to(device)

        batch_pred = model(inputs).argmax(dim=1)

        predictions.extend(batch_pred.tolist())
        targets.extend(labels.tolist())
        n_correct += torch.sum(batch_pred == labels.data)

    # Print the accuracy.
    print(f"total accuracy: {n_correct / len(dataloader.dataset)}")

    return predictions, targets


In [20]:
# confusion matrix 시각화
import matplotlib.pyplot as plt
import itertools    # confusion matrix에서 사용
from sklearn.metrics import confusion_matrix

def plot_confusion_matrix(label, pred, target_names=None, labels=True):
    """Plot the confusion matrix of ground truth ``label`` vs. ``pred``.

    Rows are true labels and columns are predicted labels, following
    ``sklearn.metrics.confusion_matrix(y_true, y_pred)``. Overall accuracy
    and misclassification rate are shown under the x axis.

    Args:
        label: ground-truth class ids.
        pred: predicted class ids.
        target_names: optional tick labels, one per class.
        labels: if True, write the count into every cell.
    """
    cm = confusion_matrix(label, pred)
    # Previously computed but never displayed; now reported in the x-label.
    accuracy = np.trace(cm) / float(np.sum(cm))

    cmap = plt.get_cmap('Blues')

    plt.figure(figsize=(9, 6))
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.colorbar()
    # Cells darker than half the max count get white text for contrast.
    thresh = cm.max() / 2

    if target_names is not None:
        tick_marks = np.arange(len(target_names))
        plt.xticks(tick_marks, target_names)
        plt.yticks(tick_marks, target_names)

    if labels:
        for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
            plt.text(j, i, "{:,}".format(cm[i, j]), horizontalalignment="center",
                     color="white" if cm[i, j] > thresh else "black")

    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel(f'Predicted label\naccuracy={accuracy:0.4f}; misclass={1 - accuracy:0.4f}')
    plt.show()

In [None]:
pred_ans, gt_ans = test_model(device, model, test_dataloader)
# plot_confusion_matrix(label, pred) takes the ground truth first. Passing
# (pred_ans, gt_ans) transposed the matrix relative to its "True label" /
# "Predicted label" axes. Class 0 is FAKE and class 1 is REAL (CustomDataset).
plot_confusion_matrix(gt_ans, pred_ans, target_names=["FAKE", "REAL"])