In [1]:
from src.lisa.utils.logger import get_logger
import torch
from torch.nn import CrossEntropyLoss
from torchvision import datasets, transforms
from torch.utils.data._utils.collate import default_collate
import numpy
from PIL import Image
from typing import Optional
from efficientnet_pytorch import EfficientNet
from dataclasses import dataclass

In [2]:
image_size = 224

#### read image

In [3]:
image_path = "data/makeup color test/Train/face color test/3.jpg"
img = Image.open(image_path)

#### image process

In [4]:
tfms = transforms.Compose([transforms.Resize(image_size), transforms.CenterCrop(image_size), 
                           transforms.ToTensor(),
                           transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),])
img = tfms(img).unsqueeze(0)


In [5]:
# 準備訓練資料
# ImageFolder假設指定路徑下有多個資料夾，每個資料夾內為同一類的圖片，資料加名稱類別名
data_path = "data/makeup color test/Train/"
train_dataset = datasets.ImageFolder(data_path,
                                     transforms.Compose([
                                        transforms.Resize(image_size),
                                        transforms.CenterCrop(image_size),
                                        transforms.RandomAffine(degrees=0, translate=(0.05, 0.05)),
                                        transforms.RandomHorizontalFlip(),
                                        transforms.ToTensor(),
                                        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                                     ]))
train_dataloader =  torch.utils.data.DataLoader(train_dataset, 
                                                batch_size=2, 
                                                shuffle=True, 
                                                num_workers=4)

In [6]:
for i,j in train_dataloader:
    print(i)
    print(j)

tensor([[[[ 0.4166,  0.3138,  0.3823,  ...,  2.2489, -2.1179, -2.1179],
          [ 0.5364, -0.6965, -1.7240,  ...,  2.2489, -2.1179, -2.1179],
          [ 2.1804,  1.6324, -0.8507,  ...,  2.2489, -2.1179, -2.1179],
          ...,
          [-2.1179, -2.1179, -2.1179,  ..., -2.1179, -2.1179, -2.1179],
          [-2.1179, -2.1179, -2.1179,  ..., -2.1179, -2.1179, -2.1179],
          [-2.1179, -2.1179, -2.1179,  ..., -2.1179, -2.1179, -2.1179]],

         [[ 0.5378,  0.4503,  0.6954,  ...,  2.4286, -2.0357, -2.0357],
          [ 0.5203, -0.8627, -1.3179,  ...,  2.4286, -2.0357, -2.0357],
          [ 2.3235,  1.4657, -0.7927,  ...,  2.4286, -2.0357, -2.0357],
          ...,
          [-2.0357, -2.0357, -2.0357,  ..., -2.0357, -2.0357, -2.0357],
          [-2.0357, -2.0357, -2.0357,  ..., -2.0357, -2.0357, -2.0357],
          [-2.0357, -2.0357, -2.0357,  ..., -2.0357, -2.0357, -2.0357]],

         [[ 0.7228,  0.6705,  1.1062,  ...,  2.6400, -1.8044, -1.8044],
          [ 0.5485, -0.9156, -

#### load EfficientNet

In [7]:
model = EfficientNet.from_pretrained('efficientnet-b0', num_classes = 5)
image_size = EfficientNet.get_image_size('efficientnet-b0') # 224
model.extract_features(img)

Loaded pretrained weights for efficientnet-b0


tensor([[[[-1.3452e-01, -2.4137e-01, -2.5781e-01,  ..., -1.5725e-01,
           -1.5298e-01, -1.8575e-01],
          [-1.2318e-01, -2.1992e-01, -7.1795e-02,  ..., -2.2001e-01,
           -1.3432e-01, -6.8094e-02],
          [-1.1476e-01, -4.0598e-02, -5.9860e-04,  ..., -8.5880e-02,
           -5.6209e-02, -1.2070e-02],
          ...,
          [ 9.4377e-01, -2.7345e-01, -1.3477e-01,  ..., -2.2265e-01,
           -2.3221e-01, -1.8630e-01],
          [ 6.4874e-02, -2.7123e-01, -2.7594e-01,  ..., -4.1425e-02,
            2.6047e+00,  2.6346e+00],
          [-2.7560e-01, -5.7917e-02, -1.5357e-02,  ..., -4.1297e-03,
            9.3423e-01,  2.4052e+00]],

         [[ 2.7389e+00,  5.4472e-01, -1.5927e-01,  ..., -1.3962e-01,
           -1.9566e-01, -1.6196e-01],
          [ 4.6547e-01, -2.7000e-01, -2.3767e-01,  ..., -1.1946e-01,
           -9.5869e-02, -1.3298e-01],
          [-1.3458e-01,  2.9402e+00,  6.5549e+00,  ..., -1.7072e-01,
           -8.7656e-02, -2.1018e-02],
          ...,
     

#### modify EfficientNet

In [8]:
from transformers.file_utils import ModelOutput
from dataclasses import dataclass


In [9]:
@dataclass
class EfficientNetOutput(ModelOutput):
    loss: torch.FloatTensor = None
    logits: torch.FloatTensor = None
    features: torch.FloatTensor = None

In [10]:
class EfficientNetModify(EfficientNet):
    def __init__(self, blocks_args=None, global_params=None):
        super().__init__(blocks_args, global_params)
        
    def forward(self, 
                inputs=None, 
                labels=None):
        print("QQQQQQQQQQQQQQQQQQQQQQQQ")
        features = super().forward(inputs)
        
        if self._global_params.include_top:
            logits = features
            features = None
        else:
            x = features.flatten(start_dim=1)
            x = self._dropout(x)
            logits = self._fc(x)
        
        loss = None
        if labels != None:
            loss_fun = CrossEntropyLoss()
            loss = loss_fun(logits, labels)
        
        return EfficientNetOutput(
            loss=loss,
            logits=logits,
            features=features
        ) 

In [11]:
model_new = EfficientNetModify.from_pretrained('efficientnet-b0', num_classes = 5, include_top=False)
model_new(img, torch.tensor([1]))

Loaded pretrained weights for efficientnet-b0


EfficientNetOutput([('loss', tensor(1.5200, grad_fn=<NllLossBackward>)),
                    ('logits',
                     tensor([[-0.0974,  0.1277,  0.1516,  0.0656, -0.0837]],
                            grad_fn=<AddmmBackward>)),
                    ('features', tensor([[[[0.1790]],
                     
                              [[0.2258]],
                     
                              [[0.0878]],
                     
                              ...,
                     
                              [[0.1900]],
                     
                              [[0.0906]],
                     
                              [[0.1389]]]], grad_fn=<MeanBackward1>))])

In [12]:
for i,j in train_dataloader:
    output = model_new(inputs = i, labels = torch.tensor(j)-1)
    print(output)

  


EfficientNetOutput(loss=tensor(1.5492, grad_fn=<NllLossBackward>), logits=tensor([[-0.0697,  0.0590,  0.1366,  0.0343, -0.1071],
        [-0.2304,  0.0861,  0.1173,  0.1579, -0.1010]],
       grad_fn=<AddmmBackward>), features=tensor([[[[-0.0724]],

         [[-0.0979]],

         [[ 0.1475]],

         ...,

         [[ 0.1349]],

         [[ 0.2845]],

         [[ 0.5453]]],


        [[[ 0.2559]],

         [[ 0.5480]],

         [[ 0.2489]],

         ...,

         [[ 0.2155]],

         [[-0.0297]],

         [[-0.1406]]]], grad_fn=<MeanBackward1>))
EfficientNetOutput(loss=tensor(1.7072, grad_fn=<NllLossBackward>), logits=tensor([[-2.4262e-01,  2.1429e-04,  2.7764e-01,  1.4214e-01, -1.2547e-01],
        [ 1.3317e-01,  1.6521e-01,  1.2484e-01, -5.3809e-03,  2.4409e-02]],
       grad_fn=<AddmmBackward>), features=tensor([[[[-0.0634]],

         [[ 0.3243]],

         [[ 0.2773]],

         ...,

         [[ 0.2529]],

         [[ 0.0617]],

         [[ 0.2084]]],


        [[[-0.08

#### Trainer

In [13]:
from transformers import Trainer, TrainingArguments, EvaluationStrategy

In [14]:
args_dict = {
    "num_train_epochs":1,          # 訓練代數
    "per_device_train_batch_size":2,    # train時的batch size
    "per_device_eval_batch_size":2,    # eval時的batch size
    "gradient_accumulation_steps":8,    # 每幾個batch update一次參數
    "warmup_steps":500,          # 前幾個batch要做warm up
    "weight_decay":0.00001,          # learning rate
    "eval_steps":500,            # 每幾個step要eval 預設500
    "save_steps":500,            # 每幾個step要save 預設500
    "logging_steps":100,
    "evaluation_strategy":EvaluationStrategy.STEPS,   # 用STEPS來判斷是否要eval
    "dataloader_num_workers":4,      # 開幾個CPU做dataloader
}

In [15]:
train_setting = {
                # -----data setting-----
                "train_dataset" : train_dataset,
#                 "eval_dataset" : valid_dataset,
#                 "train_sampler" : weight_sampler,
#                 "compute_metrics":metric.entity_mention_metric
                }

In [16]:
training_args = TrainingArguments(
                                  output_dir="./results",  # 輸出模型的資料夾
                                  **args_dict
                                 )

In [17]:
class MyTrainer(Trainer):
    def _prepare_inputs(self, inputdata):
        print("QQQ")
        print(inputdata)
        print(inputdata[0].shape)
        if isinstance(inputdata, dict):
            return super()._prepare_inputs(inputdata)
            print(inputdata)
        elif isinstance(inputdata, (torch.Tensor,list)) :
            print("~~~~~~~~~~")
            inputs = {}
            inputs["inputs"] = torch.tensor(inputdata[0])
            inputs["labels"] = torch.tensor(inputdata[1])
            print(inputs)
            return super()._prepare_inputs(inputs)

In [18]:
# get trainer
trainer = MyTrainer(
                    model=model,
                    args=training_args,
                    train_dataset = train_dataset,
                    data_collator = default_collate
                  )

In [19]:
trainer.train()

QQQ
[tensor([[[[-2.1179, -2.1179, -2.1179,  ..., -2.1179, -2.1179, -2.1179],
          [-2.1179, -2.1179, -2.1179,  ..., -2.1179, -2.1179, -2.1179],
          [-2.1179, -2.1179, -2.1179,  ..., -2.1179, -2.1179, -2.1179],
          ...,
          [ 1.8379,  0.7419, -1.3302,  ..., -2.1179, -2.1179, -2.1179],
          [ 2.0777,  1.7865,  1.5468,  ..., -2.1179, -2.1179, -2.1179],
          [ 2.2489,  2.2489,  2.2489,  ..., -2.1179, -2.1179, -2.1179]],

         [[-2.0357, -2.0357, -2.0357,  ..., -2.0357, -2.0357, -2.0357],
          [-2.0357, -2.0357, -2.0357,  ..., -2.0357, -2.0357, -2.0357],
          [-2.0357, -2.0357, -2.0357,  ..., -2.0357, -2.0357, -2.0357],
          ...,
          [ 1.8683,  1.2731, -1.0553,  ..., -2.0357, -2.0357, -2.0357],
          [ 2.2010,  2.0434,  1.7108,  ..., -2.0357, -2.0357, -2.0357],
          [ 2.4286,  2.4286,  2.4286,  ..., -2.0357, -2.0357, -2.0357]],

         [[-1.8044, -1.8044, -1.8044,  ..., -1.8044, -1.8044, -1.8044],
          [-1.8044, -1.80

  if sys.path[0] == '':
  del sys.path[0]


TypeError: forward() got an unexpected keyword argument 'labels'