In [1]:
# datamodule
# 이미지
from pytorch_lightning import LightningDataModule
from torchvision.datasets import ImageFolder
from torch.utils.data import random_split, DataLoader
import torchvision.transforms as T

In [2]:
class DataModule(LightningDataModule):
    def __init__(self):
        super().__init__()
        transform = T.Compose([
            T.Resize(256),
            T.RandomCrop(224),
            T.ToTensor(), 
            T.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225])
        ])
        dataset = ImageFolder('/data/datasets/k_fashion_detections/train', transform=transform)
        
        train_length = int(len(dataset)*.8)
        val_length = len(dataset) - train_length
        
        self.train_dataset, self.val_dataset = random_split(dataset=dataset, lengths=[train_length, val_length])
        self.batch_size = 128
        self.num_workers =4
        
    def train_dataloader(self):
        return DataLoader(dataset=self.train_dataset,
                         batch_size=self.batch_size,
                         num_workers=self.num_workers,
                         pin_memory=True,
                         drop_last=True)
    
    def val_dataloader(self):
        return DataLoader(dataset=self.val_dataset,
                         batch_size=self.batch_size,
                         num_workers=self.num_workers,
                         pin_memory=True,
                         drop_last=True)

In [4]:
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import torchvision.models as models
from pytorch_lightning import LightningModule

from torchmetrics import Accuracy, F1, Precision, Recall
from ensemble_pr import *

In [5]:
class Network(LightningModule):
    def __init__(self,
                 ## 추가된 부분
                 # 데이터가 많으면 freeze=False로 해야된다.
                 freeze=False,
                 classes_num=20,
                 ## /추가된 부분
                 learning_rate=1e-4):
        super().__init__()
        ## 추가된 부분
        self.model = Ensemble()
        
        self.loss_fn = nn.NLLLoss()
        self.learning_rate = learning_rate
        
        self.acc = Accuracy()
        self.f1 = F1(num_classes=classes_num)
        self.pre = Precision(num_classes=classes_num)
        self.recall = Recall(num_classes=classes_num)
    
    def configure_optimizers(self):
        return optim.AdamW(lr=self.learning_rate,
                          params=self.model.parameters())
        
    def forward(self, x):
        return F.log_softmax(self.model(x),dim=1)
        
    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        # log
        # default => on_epoch=False, on_step=True
        self.log('train_loss', loss, on_epoch=True, on_step=True) # => train_loss_epoch, train_loss_step
        # /log
        return {'loss': loss}
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        
        self.acc(logits, y)
        self.f1(logits, y)
        self.pre(logits, y)
        self.recall(logits, y)
        
        # log
        # default => on_epoch=True, on_step=False
        self.log('val_loss', loss)
        self.log('Accuracy', self.acc)
        self.log('F1', self.f1)
        self.log('Precision',self.pre)
        self.log('Recall', self.recall)
        
        # /log
        return {'val_loss': loss}

In [None]:
from pytorch_lightning import Trainer

trainer = Trainer(gpus=[3],
                  accelerator='dp',
                  max_epochs=10)

model = Network()
data = DataModule()
trainer.fit(model=model, datamodule=data)

  f"Passing `Trainer(accelerator={self.distributed_backend!r})` has been deprecated"
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name    | Type      | Params
--------------------------------------
0 | model   | Ensemble  | 26.7 M
1 | loss_fn | NLLLoss   | 0     
2 | acc     | Accuracy  | 0     
3 | f1      | F1        | 0     
4 | pre     | Precision | 0     
5 | recall  | Recall    | 0     
--------------------------------------
26.7 M    Trainable params
0         Non-trainable params
26.7 M    Total params
106.744   Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]