In [1]:
import torch

import os
import sys
import random
from typing import Any
import torch
from torchmetrics import Accuracy


import torch
from lightning import LightningModule
from torchmetrics import MaxMetric, MeanMetric
from torchmetrics.classification.accuracy import Accuracy

from tqdm import tqdm

sys.path.append("/home/tak/IBT/Image-back-translation")

In [2]:
class MixAugLitModule(LightningModule):
    def __init__(
        self,
        net: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler._LRScheduler,
        num_classes: int,
    ):
        super().__init__()

        self.save_hyperparameters(logger=False)

        self.net = net

        self.criterion = torch.nn.BCEWithLogitsLoss()

        # self.train_acc = Accuracy(task="multilabel", num_labels=num_classes, average='micro')
        # self.val_acc = Accuracy(task="multilabel", num_labels=num_classes, average='micro')
        # self.test_acc = Accuracy(task="multilabel", num_labels=num_classes, average='micro')

        self.train_acc_top1 = Accuracy(task="multilabel", num_labels=num_classes, top_k=1, average='micro')
        self.train_acc_top5 = Accuracy(task="multilabel", num_labels=num_classes, top_k=5, average='micro')

        self.val_acc_top1 = Accuracy(task="multiclass", num_classes=num_classes, top_k=1, average='micro')
        self.val_acc_top5 = Accuracy(task="multiclass", num_classes=num_classes, top_k=5, average='micro')

        self.test_acc_top1 = Accuracy(task="multiclass", num_classes=num_classes, top_k=1, average='micro')
        self.test_acc_top5 = Accuracy(task="multiclass", num_classes=num_classes, top_k=5, average='micro')
        print(f"num_classes: {num_classes}")
        

        self.train_loss = MeanMetric()
        self.val_loss = MeanMetric()
        self.test_loss = MeanMetric()

        self.val_acc_best = MaxMetric()

    def forward(self, x: torch.Tensor):
        return self.net(x)
    
    def on_train_start(self):
        # by default lightning executes validation step sanity checks before training starts,
        # so it's worth to make sure validation metrics don't store results from these checks
        self.val_loss.reset()
        self.val_acc_top1.reset()
        self.val_acc_top5.reset()
        self.val_acc_best.reset()

    def model_step(self, batch: Any):
        x, y = batch
        logits = self.forward(x)
        loss = self.criterion(logits, y)
        
        # For multi-label classification
        preds = torch.sigmoid(logits)
        
        return loss, preds, y


    def training_step(self, batch: Any, batch_idx: int):
        loss, preds, targets = self.model_step(batch)

        # Convert soft labels to hard labels for accuracy calculation
        hard_targets = torch.zeros_like(targets)
        hard_targets[torch.arange(targets.size(0)), targets.argmax(1)] = 1

        # update and log metrics
        self.train_loss(loss)

        self.train_acc_top1(preds, hard_targets)
        self.train_acc_top5(preds, hard_targets)
        self.log("train/loss", self.train_loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log("train/acc_top1", self.train_acc_top1, on_step=False, on_epoch=True, prog_bar=True)
        self.log("train/acc_top5", self.train_acc_top5, on_step=False, on_epoch=True, prog_bar=True)

        # return loss or backpropagation will fail
        return loss

    def on_train_epoch_end(self):
        pass

    def validation_step(self, batch: Any, batch_idx: int):
        loss, preds, targets = self.model_step(batch)

        # update and log metrics
        self.val_loss(loss)

        _, targets_indices = torch.max(targets, dim=1)
        self.val_acc_top1(preds, targets_indices)
        self.val_acc_top5(preds, targets_indices)
        self.log("val/loss", self.val_loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log("val/acc_top1", self.val_acc_top1, on_step=False, on_epoch=True, prog_bar=True)
        self.log("val/acc_top5", self.val_acc_top5, on_step=False, on_epoch=True, prog_bar=True)

    def on_validation_epoch_end(self):
        acc = self.val_acc_top1.compute()  # get current val acc
        self.val_acc_best(acc)  # update best so far val acc
        # log `val_acc_best` as a value through `.compute()` method, instead of as a metric object
        # otherwise metric would be reset by lightning after each epoch
        self.log("val/acc_best", self.val_acc_best.compute(), prog_bar=True)

    def test_step(self, batch: Any, batch_idx: int):
        loss, preds, targets = self.model_step(batch)

        # update and log metrics
        self.test_loss(loss)

        _, preds_indices = torch.max(preds, dim=1)
        _, targets_indices = torch.max(targets, dim=1)


        self.test_acc_top1(preds_indices, targets_indices)
        self.test_acc_top5(preds_indices, targets_indices)
        self.log("test/loss", self.test_loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log("test/acc_top1", self.test_acc_top1, on_step=False, on_epoch=True, prog_bar=True)
        self.log("test/acc_top5", self.test_acc_top5, on_step=False, on_epoch=True, prog_bar=True)

    def on_test_epoch_end(self):
        pass

    def configure_optimizers(self):
        """Choose what optimizers and learning-rate schedulers to use in your optimization.
        Normally you'd need one. But in the case of GANs or similar you might have multiple.

        Examples:
            https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers
        """
        optimizer = self.hparams.optimizer(params=self.parameters())
        if self.hparams.scheduler is not None:
            scheduler = self.hparams.scheduler(optimizer=optimizer)
            return {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": scheduler,
                    "monitor": "val/loss",
                    "interval": "epoch",
                    "frequency": 1,
                },
            }
        return {"optimizer": optimizer}

In [3]:
import os
from pathlib import Path
from typing import Dict



from src.data.components import MixAugDataset, CMIAImageFolder

import torch
from lightning import LightningDataModule
from torch.utils.data import DataLoader, ConcatDataset
from torchvision import datasets, transforms
from torchvision.datasets.folder import default_loader

class ConcatDataModule(LightningDataModule):
    def __init__(self, root_path: str, train_dir: str, aug_dir: str, val_dir: str, batch_size: int, num_classes=1000, concat=True, aug_num=1):
        super().__init__()
        self.root_path = root_path
        self.train_dir = os.path.join(self.root_path, train_dir)
        self.aug_dir = os.path.join(self.root_path, aug_dir)
        self.val_dir = os.path.join(self.root_path, val_dir)
        # self.test_dir = os.path.join(self.root_path, 'test')

        self.concat = concat
        self.batch_size = batch_size
        self.num_classes = num_classes

        self.train_transform = transforms.Compose([
                                transforms.Resize((512, 512)),
                                transforms.ToTensor(),
                                transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                                                    std=[0.229, 0.224, 0.225]),
                                # transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.4),
                                                    ])
        
        self.val_transform = transforms.Compose([
            transforms.Resize((512, 512)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                                std=[0.229, 0.224, 0.225]),
        ])

    def one_hot(self, x, num_classes, on_value=1., off_value=0.):
        x = x.long().view(-1, 1)
        return torch.full((x.size()[0], num_classes), off_value, device=x.device).scatter_(1, x, on_value)

    def collate_fn(self, batch):
        images, targets = list(zip(*batch))
        targets = self.one_hot(torch.tensor(targets, dtype=torch.int64), self.num_classes)
        images = torch.stack(images)
        return images, targets

    def setup(self, stage=None):
        # Split the dataset into train, val, and test sets
        # Create instances of the CustomDataset for each split
        self.original_dataset = datasets.ImageFolder(self.train_dir, self.train_transform)
        self.aug_dataset = datasets.ImageFolder(self.aug_dir, self.train_transform)

        if self.concat:
            self.train_dataset = ConcatDataset([self.original_dataset, self.aug_dataset])
        else:
            self.train_dataset = self.original_dataset

        self.val_dataset = datasets.ImageFolder(self.val_dir, self.val_transform)
        # self.test_dataset = datasets.ImageFolder(self.test_dir, self.val_transform)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=4, pin_memory=True, collate_fn=self.collate_fn)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=4, pin_memory=True, collate_fn=self.collate_fn)

    # def test_dataloader(self):
    #     return DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=4, pin_memory=True, collate_fn=self.val_collate_fn)


In [4]:
concat_data_module = ConcatDataModule(root_path='/data2/tak/1000way', train_dir='train_10percent', aug_dir='train_10percent_IBT', val_dir='val_formatted', batch_size=128, num_classes=1000, concat=False)
concat_data_module.setup()
val_loader = concat_data_module.val_dataloader()

In [5]:
check_point_path = "/nvme_data1/tak/wandb/concat/runs/2023-10-31_14-36-46/checkpoints/epoch_012.ckpt"
model = MixAugLitModule.load_from_checkpoint(check_point_path)
model.eval()

device = torch.device("cuda:7" if torch.cuda.is_available() else "cpu")
print(f"device: {device}")

model = model.to(device)

  rank_zero_warn(


num_classes: 1000
device: cuda:7


In [6]:
acc_top1 = Accuracy(task="multiclass", num_classes=1000, top_k=1, average='micro').to(device)
acc_top5 = Accuracy(task="multiclass", num_classes=1000, top_k=5, average='micro').to(device)

In [7]:
# 검증 과정
val_loss = 0.0
with torch.no_grad():
    for batch in tqdm(val_loader):
        x, y = batch
        x, y = x.to(device), y.to(device)
        logits = model(x)
        
        # 손실 계산
        loss = model.criterion(logits, y)
        val_loss += loss.item()

        # 예측 결과를 확률로 변환
        preds = torch.sigmoid(logits)

        # 타겟 인덱스 추출
        _, targets_indices = torch.max(y, dim=1)

        # 정확도 업데이트
        acc_top1(preds, targets_indices)
        acc_top5(preds, targets_indices)

# 평균 손실 계산
val_loss /= len(val_loader)

# 성능 메트릭 출력
print(f"Validation Loss: {val_loss}")
print(f"Validation Top-1 Accuracy: {acc_top1.compute()}")
print(f"Validation Top-5 Accuracy: {acc_top5.compute()}")

100%|██████████| 391/391 [06:25<00:00,  1.01it/s]

Validation Loss: 0.005966499004789325
Validation Top-1 Accuracy: 0.12479999661445618
Validation Top-5 Accuracy: 0.28859999775886536



