In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from transformers import AutoImageProcessor, AutoModelForImageClassification
from transformers import TrainingArguments, Trainer
from datasets import load_dataset
import numpy as np
import pandas as pd
import cv2
import random
import os
import evaluate
from tqdm import tqdm
import warnings
warnings.filterwarnings(action = "ignore")

from sklearn.model_selection import train_test_split

# Central experiment configuration.
# NOTE(review): the TrainingArguments cell sets its own learning rate, epoch count and
# batch size, and the Trainer picks its own device; confirm which CFG entries are meant to take effect.
CFG = dict(
    LEARNING_RATE=1e-4,
    EPOCHS=30,
    BATCH_SIZE=32,
    SEED=42,
    DEVICE=torch.device("cuda:1" if torch.cuda.is_available() else "cpu"),
    MODEL="microsoft/swinv2-large-patch4-window12to16-192to256-22kto1k-ft",
)

# Alternative ways to pin GPU 1 (currently disabled):
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"
# torch.cuda.set_device(1)
# torch.cuda.current_device()

  from .autonotebook import tqdm as notebook_tqdm


1

In [2]:
def seed_everything(seed):
    """Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed (int): Seed applied to every random number generator.
    """
    random.seed(seed)
    # Only affects subprocesses started after this point; the current
    # interpreter's hash randomization is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not just the current one (no-op without CUDA).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN autotune and pick kernels non-deterministically,
    # which defeats `deterministic = True`; it must be off for reproducible runs.
    torch.backends.cudnn.benchmark = False

# Seed all RNGs once, before any data splitting or model construction.
seed_everything(seed=CFG["SEED"])

In [3]:
data = pd.read_csv("train_.csv")
# Build the class-name -> integer-id mapping (np.unique returns names in sorted order).
classes = np.unique(data["label"])
class_name = {name: i for i, name in enumerate(classes)}
# Stratified 90/10 split. Uses the configured seed rather than a second hard-coded 42.
# Only the X outputs are needed, so the label Series is passed via `stratify` alone.
trainset, valset = train_test_split(
    data, test_size = 0.1, stratify = data["label"], random_state = CFG["SEED"]
)
# Renumber rows 0..n-1 and drop the index column saved in the CSV ("Unnamed: 0").
trainset = trainset.reset_index(drop = True).drop(columns = ["Unnamed: 0"])
valset = valset.reset_index(drop = True).drop(columns = ["Unnamed: 0"])

In [4]:
# Image processor for the checkpoint; used here only for its target size and normalization stats.
image_processor = AutoImageProcessor.from_pretrained(CFG["MODEL"])

# (height, width) expected by the model.
target_size = (image_processor.size["height"], image_processor.size["width"])

# HWC uint8 RGB array -> float CHW tensor -> resize to the model's input size -> normalize.
# NOTE(review): Resize uses torchvision's default interpolation, which may differ from the
# image processor's own resampling; confirm this train-time preprocessing matches inference.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize(size=target_size),
        transforms.Normalize(mean=image_processor.image_mean, std=image_processor.image_std),
    ]
)

class ImageSet(Dataset):
    """Map-style dataset yielding ``(rgb_image, class_id)`` pairs.

    Only the low-resolution image (``img_low``) is loaded and returned. The
    ``img_high`` paths are stored but not read. ``transform`` is stored but not
    applied here: the collator applies it per batch, so images are returned as
    the RGB numpy arrays produced by OpenCV.

    Args:
        img_low: Indexable sequence (e.g. a pandas Series) of image paths.
        img_high: Indexable sequence of upscaled image paths (currently unused).
        transform: Stored for callers; not applied in ``__getitem__``.
        class_name: Mapping from raw label to integer class id. If ``None``,
            raw labels are returned unchanged.
        label: Indexable sequence of raw labels; its length is the dataset size.
    """

    def __init__(self, img_low, img_high, transform = None, class_name = None, label = None):
        self.img_low = img_low
        self.img_high = img_high
        self.label = label
        self.transform = transform
        self.class_name = class_name

    def __len__(self):
        return len(self.label)

    @staticmethod
    def _at(seq, idx):
        """Positional lookup; uses ``.iloc`` so pandas objects with a non-0..n-1 index still work."""
        return seq.iloc[idx] if hasattr(seq, "iloc") else seq[idx]

    def _encode_label(self, label):
        """Translate a raw label into its class id using this dataset's own mapping."""
        if self.class_name is None:
            return label
        return self.class_name[label]

    def __getitem__(self, idx):
        path = self._at(self.img_low, idx)
        image = cv2.imread(path)
        if image is None:
            # cv2.imread signals unreadable/missing files by returning None, not raising.
            raise FileNotFoundError(f"Could not read image: {path}")
        # OpenCV decodes as BGR; the model's normalization stats expect RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        return image, self._encode_label(self._at(self.label, idx))

def collator(data, transform):
    """Stack a list of ``(image, label)`` samples into a Trainer batch dict.

    Args:
        data: Sequence of ``(image, label)`` pairs.
        transform: Callable mapping one image to a tensor; all outputs must share
            a shape so they can be stacked along a new leading dimension.

    Returns:
        ``{"pixel_values": stacked tensor, "labels": tensor of labels}``.
    """
    images, targets = zip(*data)
    transformed = list(map(transform, images))
    return {
        "pixel_values": torch.stack(transformed),
        "labels": torch.tensor(list(targets)),
    }

# Cache for the default F1 metric so it is loaded once, not on every evaluation.
_F1_METRIC = None


def compute_metrics(eval_pred, metric = None):
    """Compute macro-averaged F1 for a Hugging Face ``Trainer`` evaluation.

    Args:
        eval_pred: ``(predictions, labels)`` pair. ``predictions`` holds per-class
            scores; the predicted class is the argmax over axis 1. If it is a
            tuple (presumably when the model returns extra outputs), its first
            element is taken as the logits.
        metric: Optional object with a ``compute(predictions=, references=, average=)``
            method. Defaults to ``evaluate.load("f1")``, loaded once and cached.

    Returns:
        The metric's result dict. Its ``"f1"`` key is what
        ``metric_for_best_model="f1"`` selects on.
    """
    global _F1_METRIC
    if metric is None:
        if _F1_METRIC is None:
            # Previously evaluate.load ran on every evaluation call.
            _F1_METRIC = evaluate.load("f1")
        metric = _F1_METRIC
    predictions, labels = eval_pred
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    predictions = np.argmax(predictions, axis = 1)
    return metric.compute(
        predictions = predictions, references = labels, average = "macro"
    )

Could not find image processor class in the image processor config or the model config. Loading based on pattern matching with the model's feature extractor configuration.


In [5]:
def model_init(classes, class_name):
    """Load the checkpoint named by ``CFG["MODEL"]`` with a classification head sized for ``classes``.

    Args:
        classes: Sequence of class names; its length sets the number of output logits.
        class_name: Mapping ``name -> id``; used as ``label2id`` and inverted for ``id2label``.

    Returns:
        The model from ``AutoModelForImageClassification.from_pretrained``.
        ``ignore_mismatched_sizes=True`` lets the checkpoint's differently sized
        classifier head be re-initialized instead of raising a shape error.
    """
    checkpoint = CFG["MODEL"]
    n_labels = len(classes)
    id2label = dict((index, name) for name, index in class_name.items())
    return AutoModelForImageClassification.from_pretrained(
        checkpoint,
        num_labels=n_labels,
        id2label=id2label,
        label2id=class_name,
        ignore_mismatched_sizes=True,
    )

In [6]:
def _to_imageset(frame):
    """Wrap one split DataFrame (img_path / upscale_img_path / label columns) in an ImageSet."""
    return ImageSet(
        img_low=frame["img_path"],
        img_high=frame["upscale_img_path"],
        transform=transform,
        class_name=class_name,
        label=frame["label"],
    )


# NOTE(review): rebinding `trainset` from a DataFrame to an ImageSet makes this cell
# non-idempotent; re-running it without first re-running the split cell will fail.
trainset = _to_imageset(trainset)
validset = _to_imageset(valset)

In [7]:
device = CFG["DEVICE"]
# NOTE(review): `device` is not passed to the Trainer, which selects its own device
# (the recorded `args.device` output was cuda:0, not CFG's cuda:1).
# NOTE(review): these hyperparameters differ from CFG (LEARNING_RATE=1e-4, EPOCHS=30,
# BATCH_SIZE=32). Effective batch per optimizer step is 16 x 2 accumulation = 32 per device.
# NOTE(review): newer transformers releases rename `evaluation_strategy` to `eval_strategy`.
training_kwargs = dict(
    output_dir="./swin-transformer",
    save_strategy="epoch",
    evaluation_strategy="epoch",
    learning_rate=1e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    gradient_accumulation_steps=2,
    num_train_epochs=50,
    weight_decay=1e-4,
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    logging_dir="logs",
    logging_steps=400,
    seed=CFG["SEED"],
    warmup_ratio=0.1,
    label_smoothing_factor=1e-3,
    # Keep all dataset fields; ImageSet returns raw tuples that the custom collator consumes.
    remove_unused_columns=False,
)
args = TrainingArguments(**training_kwargs)
args.device

device(type='cuda', index=0)

In [None]:
def _collate_batch(features):
    """Batch ImageSet samples using the notebook-level ``transform`` (looked up at call time)."""
    return collator(features, transform)


trainer = Trainer(
    model=model_init(classes, class_name),
    args=args,
    train_dataset=trainset,
    eval_dataset=validset,
    data_collator=_collate_batch,
    compute_metrics=compute_metrics,
    # Presumably passed so the processor config is saved with checkpoints; confirm.
    tokenizer=image_processor,
)
trainer.train()

Some weights of Swinv2ForImageClassification were not initialized from the model checkpoint at microsoft/swinv2-large-patch4-window12to16-192to256-22kto1k-ft and are newly initialized because the shapes did not match:
- classifier.weight: found shape torch.Size([1000, 1536]) in the checkpoint and torch.Size([25, 1536]) in the model instantiated
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([25]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch,Training Loss,Validation Loss,F1
0,3.058,1.549684,0.721044
2,0.805,0.187635,0.942719
2,0.2053,0.164334,0.955607


Downloading builder script: 100%|██████████████████████████████████████████████████████████████| 6.77k/6.77k [00:00<00:00, 3.31MB/s]
