<a href="https://colab.research.google.com/github/Lukas4319/Animal_classification/blob/main/OpenSource_Software.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install wandb

# 라이브러리 호출 및 GPU 설정

In [None]:
from google.colab import drive
drive.mount('/content/drive')
import sys
sys.path.append("/content/drive/MyDrive/Colab Notebooks")
import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn, optim
import torchvision
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from tqdm import tqdm
import time
from torch.optim.lr_scheduler import StepLR
import wandb
# Compute device used for the model and every batch: the GPU when CUDA is
# available, otherwise the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Hyper Parameter 및 경로 설정

In [41]:
# Number of samples per batch for the train/val/test DataLoaders.
BATCH_SIZE = 32
# Initial learning rate handed to the AdamW optimizer.
LR = 0.0001
# StepLR schedule: every LR_STEP epochs the learning rate is multiplied by LR_GAMMA.
LR_STEP = 5
LR_GAMMA = 0.9
# Number of training epochs.
EPOCH = 20
# Loss function passed to the Trainer.
criterion = nn.CrossEntropyLoss()
# Flag read by the training cell to decide whether a training run is started.
new_model_train = True
# Name of the model factory function to instantiate (e.g. "resnet18").
model_type = "resnet18"
# Dataset label; in this file it is only used to build the output file names below.
dataset = "Animal10"
# Destination of the model state_dict with the lowest validation loss.
save_model_path = f"/content/drive/MyDrive/Colab Notebooks/result/{model_type}_{dataset}.pt"
# Intended destination for the training history.
# NOTE(review): nothing in the visible part of this file writes to this path - confirm.
save_history_path = f"/content/drive/MyDrive/Colab Notebooks/result/{model_type}_history_{dataset}.pt"

# Data Load

In [None]:
!pip install gdown==4.6.0

In [None]:
!gdown https://drive.google.com/uc?id=#data_id

In [None]:
!unzip /content/Animals10.zip

# Data preprocessing

In [34]:
# Every split uses the same preprocessing: resize to 224x224, then convert the
# image to a tensor. Each split still receives its own independent Compose
# object so that one of them can later be changed without touching the others.
transform_train, transform_val, transform_test = (
    transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])
    for _ in range(3)
)

In [35]:
# One ImageFolder dataset per split, each paired with that split's transform.
# The datasets are created in train / val / test order.
train_DS, val_DS, test_DS = (
    torchvision.datasets.ImageFolder(root=split_root, transform=split_transform)
    for split_root, split_transform in (
        ("/content/Animals10/train_DS", transform_train),
        ("/content/Animals10/val_DS", transform_val),
        ("/content/Animals10/test_DS", transform_test),
    )
)

In [36]:
# Only the training loader shuffles: a random sample order matters for SGD-style
# optimisation, while validation and test metrics are averages over the whole
# split and gain nothing from shuffling. Keeping evaluation order fixed makes
# those passes reproducible from run to run.
train_DL = DataLoader(train_DS, batch_size=BATCH_SIZE, shuffle=True)
val_DL = DataLoader(val_DS, batch_size=BATCH_SIZE, shuffle=False)
test_DL = DataLoader(test_DS, batch_size=BATCH_SIZE, shuffle=False)

# Train & Test

In [19]:
class Trainer:
    """Train, validate and evaluate a classification model.

    The trainer keeps a per-epoch history of loss and accuracy for the
    training and validation loaders and saves the model's ``state_dict``
    every time the validation loss improves.

    Args:
        model: Module to train; it is moved to ``device``.
        train_loader: Iterable of ``(inputs, targets)`` training batches.
        val_loader: Iterable of ``(inputs, targets)`` validation batches.
        criterion: Callable ``criterion(predictions, targets)`` returning a
            scalar loss tensor.
        optimizer: Optimizer that updates the parameters of ``model``.
        device: Device on which the model and all batches are placed.
    """

    def __init__(self, model, train_loader, val_loader, criterion, optimizer, device=DEVICE):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.criterion = criterion
        self.optimizer = optimizer
        self.device = device
        # No learning-rate schedule until set_scheduler() is called.
        self.scheduler = None
        # One entry is appended to each list per epoch run by train().
        self.history = {
            "train_loss": [],
            "val_loss": [],
            "train_acc": [],
            "val_acc": [],
        }
        # Lowest validation loss seen so far; decides when a checkpoint is written.
        self.best_loss = float("inf")

    def set_scheduler(self, step_size, gamma=0.1):
        """Attach a ``StepLR`` schedule to the optimizer.

        Args:
            step_size: Number of scheduler steps (one per epoch in ``train``)
                between learning-rate decays.
            gamma: Multiplicative decay factor.
        """
        self.scheduler = StepLR(self.optimizer, step_size=step_size, gamma=gamma)

    def _run_epoch(self, loader, is_train=True):
        """Run one full pass over ``loader``.

        Args:
            loader: Iterable of ``(inputs, targets)`` batches.
            is_train: When true the model is put in training mode and the
                optimizer is stepped after every batch; otherwise the model
                is put in eval mode and gradients are disabled.

        Returns:
            Tuple ``(avg_loss, accuracy)``: the per-sample average loss and
            the accuracy as a percentage.

        Raises:
            ValueError: If ``loader`` yields no samples, since the averages
                would be undefined.
        """
        mode = "train" if is_train else "val"
        if is_train:
            self.model.train()
        else:
            self.model.eval()

        running_loss = 0
        correct = 0
        total = 0

        for x_batch, y_batch in tqdm(loader, desc=f"{mode} Epoch", leave=False):
            x_batch, y_batch = x_batch.to(self.device), y_batch.to(self.device)
            with torch.set_grad_enabled(is_train):
                y_pred = self.model(x_batch)
                loss = self.criterion(y_pred, y_batch)

                if is_train:
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()

            # Weight the batch loss by the batch size so that a smaller final
            # batch does not skew the average (assumes the criterion returns
            # the mean over the batch).
            batch_size = x_batch.size(0)
            running_loss += loss.item() * batch_size
            correct += (y_pred.argmax(1) == y_batch).sum().item()
            total += batch_size

        if total == 0:
            # Fail with an explicit message instead of a bare ZeroDivisionError.
            raise ValueError(f"Cannot run a {mode} epoch: the data loader yielded no samples.")

        avg_loss = running_loss / total
        accuracy = correct / total * 100
        return avg_loss, accuracy

    def train(self, epochs, save_model_path, log_wandb=False):
        """Train for ``epochs`` epochs, validating after each one.

        Args:
            epochs: Number of epochs to run.
            save_model_path: Where the model ``state_dict`` is saved whenever
                the validation loss improves on the best value seen so far.
            log_wandb: When true, per-epoch metrics are sent to ``wandb.log``.

        Returns:
            The history dict with keys ``train_loss``, ``val_loss``,
            ``train_acc`` and ``val_acc``.
        """
        import os

        # Create the checkpoint directory up front so that the first
        # torch.save() cannot fail after a full epoch has been computed.
        if isinstance(save_model_path, (str, os.PathLike)):
            checkpoint_dir = os.path.dirname(os.fspath(save_model_path))
            if checkpoint_dir:
                os.makedirs(checkpoint_dir, exist_ok=True)

        for epoch in range(epochs):
            start_time = time.time()
            current_lr = self.optimizer.param_groups[0]["lr"]
            print(f"Epoch {epoch + 1}/{epochs}, LR: {current_lr}")

            train_loss, train_acc = self._run_epoch(self.train_loader, is_train=True)
            val_loss, val_acc = self._run_epoch(self.val_loader, is_train=False)

            self.history["train_loss"].append(train_loss)
            self.history["val_loss"].append(val_loss)
            self.history["train_acc"].append(train_acc)
            self.history["val_acc"].append(val_acc)

            # Keep only the weights with the lowest validation loss.
            if val_loss < self.best_loss:
                self.best_loss = val_loss
                torch.save(self.model.state_dict(), save_model_path)

            if self.scheduler is not None:
                self.scheduler.step()

            epoch_time = time.time() - start_time
            print(
                f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, "
                f"Train Acc: {train_acc:.2f}%, Val Acc: {val_acc:.2f}%, Time: {epoch_time:.2f}s"
            )

            if log_wandb:
                wandb.log({
                    "train_loss": train_loss,
                    "val_loss": val_loss,
                    "train_acc": train_acc,
                    "val_acc": val_acc,
                    "epoch": epoch,
                })

        return self.history

    def evaluate(self, test_loader):
        """Evaluate the model's current weights on ``test_loader``.

        Args:
            test_loader: Iterable of ``(inputs, targets)`` batches.

        Returns:
            Tuple ``(test_loss, test_acc)``: average loss and accuracy in percent.
        """
        test_loss, test_acc = self._run_epoch(test_loader, is_train=False)
        print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.2f}%")
        return test_loss, test_acc

# Model 생성

In [37]:
class BasicBlock(nn.Module):
    """Residual block built from two 3x3 convolutions.

    Args:
        in_channels: Number of channels of the block input.
        inner_channels: Number of channels produced by the first convolution;
            the block outputs ``inner_channels * expansion`` channels.
        stride: Stride of the first convolution.
        projection: Optional module applied to the input on the shortcut
            path so that it matches the shape of the residual branch. When
            ``None`` the input itself is used as the shortcut.
    """

    # Ratio between the block's output channels and ``inner_channels``.
    expansion = 1

    def __init__(self, in_channels, inner_channels, stride = 1, projection = None):
        super().__init__()

        out_channels = inner_channels * self.expansion
        self.residual = nn.Sequential(nn.Conv2d(in_channels, inner_channels, 3, stride=stride, padding=1, bias=False),
                                      nn.BatchNorm2d(inner_channels),
                                      nn.ReLU(inplace=True),
                                      nn.Conv2d(inner_channels, out_channels, 3, padding=1, bias=False),
                                      # Sized to the channels the preceding conv actually
                                      # emits, so the block stays valid if `expansion`
                                      # is ever overridden.
                                      nn.BatchNorm2d(out_channels))
        self.projection = projection
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``relu(residual(x) + shortcut(x))``."""
        residual = self.residual(x)

        if self.projection is not None:
            shortcut = self.projection(x)  # projection shortcut (dashed connection)
        else:
            shortcut = x  # identity shortcut (solid connection)

        out = self.relu(residual + shortcut)
        return out

class Bottleneck(nn.Module):
    """Residual block with a 1x1 -> 3x3 -> 1x1 convolution stack.

    Args:
        in_channels: Number of channels of the block input.
        inner_channels: Width of the first two convolutions; the block
            outputs ``inner_channels * expansion`` channels.
        stride: Stride of the 3x3 convolution.
        projection: Optional module applied to the input on the shortcut
            path; when ``None`` the input itself is the shortcut.
    """

    # Ratio between the block's output channels and ``inner_channels``.
    expansion = 4

    def __init__(self, in_channels, inner_channels, stride = 1, projection = None):
        super().__init__()

        out_channels = inner_channels * self.expansion
        # Layers are created in execution order: reduce, spatial 3x3, expand.
        branch = [
            nn.Conv2d(in_channels, inner_channels, 1, bias=False),
            nn.BatchNorm2d(inner_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(inner_channels, inner_channels, 3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(inner_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(inner_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
        ]
        self.residual = nn.Sequential(*branch)

        self.projection = projection
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``relu(residual(x) + shortcut(x))``."""
        branch_out = self.residual(x)
        skip = x if self.projection is None else self.projection(x)
        return self.relu(branch_out + skip)

class ResNet(nn.Module):
    """ResNet backbone followed by a multi-layer fully connected classifier.

    Args:
        block: Residual block class. It must expose an ``expansion`` class
            attribute and a ``residual`` sequence whose last element has a
            ``weight``, and accept
            ``(in_channels, inner_channels, stride, projection)``.
        num_block_list: Number of blocks in each of the four stages.
        num_classes: Size of the final output layer.
        zero_init_residual: When true, the ``weight`` of the last layer of
            every block's residual branch is initialised to zero.
    """

    def __init__(self, block, num_block_list, num_classes = 10, zero_init_residual = True):
        super().__init__()

        # Channel count entering the next stage; make_stage() updates it.
        self.in_channels = 64

        # Stem: 7x7 stride-2 convolution followed by a 3x3 stride-2 max pool.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Four residual stages; every stage after the first uses stride 2.
        depth1, depth2, depth3, depth4 = num_block_list[0], num_block_list[1], num_block_list[2], num_block_list[3]
        self.stage1 = self.make_stage(block, 64, depth1, stride=1)
        self.stage2 = self.make_stage(block, 128, depth2, stride=2)
        self.stage3 = self.make_stage(block, 256, depth3, stride=2)
        self.stage4 = self.make_stage(block, 512, depth4, stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        # Classifier head: Linear+ReLU pairs through the listed widths,
        # finished by a Linear layer onto num_classes.
        head_widths = [512 * block.expansion, 1024, 512, 256, 128, 64, 32, 16]
        head_layers = []
        for fan_in, fan_out in zip(head_widths, head_widths[1:]):
            head_layers.append(nn.Linear(fan_in, fan_out))
            head_layers.append(nn.ReLU())
        head_layers.append(nn.Linear(head_widths[-1], num_classes))
        self.fc = nn.Sequential(*head_layers)

        # Kaiming-normal initialisation for every convolution in the network.
        for conv in (m for m in self.modules() if isinstance(m, nn.Conv2d)):
            nn.init.kaiming_normal_(conv.weight, mode="fan_out", nonlinearity="relu")

        if zero_init_residual:
            for res_block in (m for m in self.modules() if isinstance(m, block)):
                nn.init.constant_(res_block.residual[-1].weight, 0)

    def make_stage(self, block, inner_channels, num_blocks, stride = 1):
        """Build one stage of ``num_blocks`` residual blocks.

        Only the first block receives ``stride`` and, when the stride is not 1
        or the channel count changes, a 1x1 convolution + batch-norm
        projection for the shortcut. ``self.in_channels`` is updated to the
        stage's output width.

        Returns:
            ``nn.Sequential`` holding the blocks of the stage.
        """
        out_channels = inner_channels * block.expansion

        projection = None
        if stride != 1 or self.in_channels != out_channels:
            projection = nn.Sequential(
                nn.Conv2d(self.in_channels, out_channels, 1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels))

        stage_blocks = [block(self.in_channels, inner_channels, stride, projection)]
        self.in_channels = out_channels
        stage_blocks.extend(block(self.in_channels, inner_channels) for _ in range(num_blocks - 1))

        return nn.Sequential(*stage_blocks)

    def forward(self, x):
        """Run the stem, the four stages, global average pooling and the head."""
        features = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        for stage in (self.stage1, self.stage2, self.stage3, self.stage4):
            features = stage(features)

        pooled = torch.flatten(self.avgpool(features), 1)
        return self.fc(pooled)

In [38]:
def resnet18(**kwargs):
    """Build a ResNet from BasicBlock with [2, 2, 2, 2] blocks per stage.

    Keyword arguments are forwarded unchanged to ``ResNet``.
    """
    blocks_per_stage = [2, 2, 2, 2]
    return ResNet(BasicBlock, blocks_per_stage, **kwargs)

def resnet34(**kwargs):
    """Build a ResNet from BasicBlock with [3, 4, 6, 3] blocks per stage.

    Keyword arguments are forwarded unchanged to ``ResNet``.
    """
    blocks_per_stage = [3, 4, 6, 3]
    return ResNet(BasicBlock, blocks_per_stage, **kwargs)

def resnet50(**kwargs):
    """Build a ResNet from Bottleneck with [3, 4, 6, 3] blocks per stage.

    Keyword arguments are forwarded unchanged to ``ResNet``.
    """
    blocks_per_stage = [3, 4, 6, 3]
    return ResNet(Bottleneck, blocks_per_stage, **kwargs)

In [None]:
# Resolve the architecture by name through an explicit registry rather than
# exec(): a typo in model_type now produces a clear error, and no arbitrary
# string is ever executed as code.
_model_factories = {
    "resnet18": resnet18,
    "resnet34": resnet34,
    "resnet50": resnet50,
}
if model_type not in _model_factories:
    raise ValueError(
        f"Unknown model_type {model_type!r}; expected one of {sorted(_model_factories)}"
    )
model = _model_factories[model_type]().to(DEVICE)
print(model)

# Smoke test: push one training batch through the model and print the output shape.
x_batch, _ = next(iter(train_DL))
print(model(x_batch.to(DEVICE)).shape)

# Model Train

In [None]:
# new_model_train = True  -> train `model` from its fresh initialisation.
# new_model_train = False -> first restore the weights saved at save_model_path,
#                            then continue training from them.
# (The original second branch repeated the `if` condition and could never run.)
if not new_model_train:
    # load_state_dict() updates `model` in place; map_location lets a
    # checkpoint written on a GPU be restored on a CPU-only runtime.
    model.load_state_dict(torch.load(save_model_path, map_location=DEVICE))

optimizer = optim.AdamW(model.parameters(), lr=LR)
trainer = Trainer(model=model,
                  train_loader=train_DL,
                  val_loader=val_DL,
                  criterion=criterion,
                  optimizer=optimizer,
                  device=DEVICE)
trainer.set_scheduler(step_size=LR_STEP, gamma=LR_GAMMA)
# The weights with the lowest validation loss of this run are written to
# save_model_path, overwriting any checkpoint already stored there.
history = trainer.train(epochs=EPOCH,
                        save_model_path=save_model_path,
                        log_wandb=False)

# Model Load

In [None]:
# load_state_dict() restores the saved weights into `model` in place and
# returns only a report of missing/unexpected keys, not the model itself, so
# its return value must not be kept as "the loaded model".
# map_location lets a checkpoint written on a GPU be loaded on a CPU-only runtime.
model.load_state_dict(torch.load(save_model_path, map_location=DEVICE))
load_model = model

# Model Test

In [None]:
# Evaluate the model held by `trainer` on the test loader; evaluate() prints the
# test loss and accuracy and returns them. Requires the training cell to have
# been run so that `trainer` exists.
test_loss, test_acc = trainer.evaluate(test_DL)