# L2CS-Net의 train.py 코드 수정

- Adapted from: https://github.com/Ahmednull/L2CS-Net/blob/main/train.py

In [None]:
import torch
import torch.nn as nn
import torchvision.models as models

class L2CS_MobileNetV3(nn.Module):
    """Two-headed gaze classifier built on a torchvision MobileNetV3 backbone.

    The backbone is created with ``pretrained=True`` and its ``classifier``
    attribute is replaced by ``nn.Identity()``. Two independent heads
    (``Linear -> ReLU -> Linear``) then map the backbone output to
    ``num_bins`` logits each: one head for yaw, one for pitch.

    Args:
        num_bins: Number of logits produced by each head.
        version: ``'large'`` or ``'small'``; selects the backbone and the
            feature width fed to the heads (960 or 576 respectively).

    Raises:
        ValueError: If ``version`` is neither ``'large'`` nor ``'small'``.
    """

    def __init__(self, num_bins=90, version='large'):
        super().__init__()

        # (version tag, backbone constructor, input width of the heads)
        known_backbones = (
            ('large', models.mobilenet_v3_large, 960),
            ('small', models.mobilenet_v3_small, 576),
        )
        for tag, build_backbone, width in known_backbones:
            if version == tag:
                self.feature_extractor = build_backbone(pretrained=True)
                feature_dim = width
                break
        else:
            raise ValueError("Invalid MobileNetV3 version. Choose 'large' or 'small'.")

        self.feature_extractor.classifier = nn.Identity()

        def make_head():
            return nn.Sequential(
                nn.Linear(feature_dim, 256),
                nn.ReLU(),
                nn.Linear(256, num_bins),
            )

        # The yaw head is created before the pitch head so that both the
        # registration order and the random initialisation order are kept.
        self.fc_yaw_gaze = make_head()
        self.fc_pitch_gaze = make_head()

    def forward(self, x):
        """Return ``(pitch_logits, yaw_logits)`` computed from the batch ``x``."""
        shared = self.feature_extractor(x)
        yaw_logits = self.fc_yaw_gaze(shared)
        pitch_logits = self.fc_pitch_gaze(shared)
        return pitch_logits, yaw_logits


In [None]:
# === Parameter Grouping Functions ===
def get_ignored_params(model):
    """Yield each parameter of the backbone's first feature block exactly once.

    Intended for an optimizer group whose learning rate is 0. As a side
    effect, every batch-norm layer inside the block is switched to eval mode
    when the generator is consumed.

    The previous implementation iterated ``named_parameters()`` inside
    ``named_modules()``; both recurse, so every parameter was yielded once
    per enclosing module and the optimizer group contained duplicates. It
    also detected batch-norm layers only by a ``'bn'`` substring in the
    module name, which never matches index-named children such as ``'1'``.

    Note: because the group no longer contains duplicates, optimizer
    checkpoints saved with the old duplicated groups have a different group
    size and will not load into an optimizer built from this function.

    Args:
        model: Object exposing ``feature_extractor.features`` as an indexable
            container of ``nn.Module`` blocks.

    Yields:
        torch.nn.Parameter: Parameters of ``features[0]`` in registration order.
    """
    batch_norm_types = (
        nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm,
    )
    stem = model.feature_extractor.features[0]
    for module_name, module in stem.named_modules():
        # Name check is kept for backward compatibility; the type check
        # catches layers whose names are plain indices.
        if 'bn' in module_name or isinstance(module, batch_norm_types):
            module.eval()
    # Module.parameters() already recurses and de-duplicates.
    yield from stem.parameters()

def get_non_ignored_params(model):
    """Yield each parameter of the backbone's feature blocks after the first, once.

    Intended for the fine-tuned optimizer group. As a side effect, every
    batch-norm layer inside those blocks is switched to eval mode when the
    generator is consumed.

    The previous implementation iterated ``named_parameters()`` inside
    ``named_modules()``; both recurse, so a parameter nested k modules deep
    was yielded k+1 times. With duplicates in a trainable group the optimizer
    warns and processes the same parameter several times per step. It also
    detected batch-norm layers only by a ``'bn'`` substring in the module
    name, which never matches index-named children.

    Note: because the group no longer contains duplicates, optimizer
    checkpoints saved with the old duplicated groups have a different group
    size and will not load into an optimizer built from this function.

    Args:
        model: Object exposing ``feature_extractor.features`` as a sliceable
            ``nn.Module`` container (e.g. ``nn.Sequential``).

    Yields:
        torch.nn.Parameter: Parameters of ``features[1:]`` in registration order.
    """
    batch_norm_types = (
        nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm,
    )
    deeper_blocks = model.feature_extractor.features[1:]
    for module_name, module in deeper_blocks.named_modules():
        # Name check is kept for backward compatibility; the type check
        # catches layers whose names are plain indices.
        if 'bn' in module_name or isinstance(module, batch_norm_types):
            module.eval()
    # Module.parameters() already recurses and de-duplicates.
    yield from deeper_blocks.parameters()

def get_fc_params(model):
    """Yield each parameter of the yaw and pitch heads exactly once.

    The previous implementation iterated ``named_parameters()`` inside
    ``named_modules()``; both recurse, so every head parameter was yielded
    twice (once via the ``Sequential`` itself and once via its ``Linear``
    child), leaving duplicates in the optimizer group.

    Note: because the group no longer contains duplicates, optimizer
    checkpoints saved with the old duplicated groups have a different group
    size and will not load into an optimizer built from this function.

    Args:
        model: Object exposing ``fc_yaw_gaze`` and ``fc_pitch_gaze`` modules.

    Yields:
        torch.nn.Parameter: Yaw-head parameters followed by pitch-head
        parameters, each in registration order.
    """
    for head in (model.fc_yaw_gaze, model.fc_pitch_gaze):
        # Module.parameters() already recurses and de-duplicates.
        yield from head.parameters()


In [None]:
import os
import numpy as np
import cv2


import torch
from torch.utils.data.dataset import Dataset
from torchvision import transforms
from PIL import Image, ImageFilter


class Gaze360(Dataset):
    """Dataset of face images with binned and continuous gaze-angle labels.

    Label files are plain text. The first line is skipped as a header; every
    other line is split on single spaces, of which this class reads:

    * field 0 - face image path, joined onto ``root``;
    * field 3 - a sample name, returned unchanged;
    * field 5 - two comma-separated floats that are multiplied by
      ``180 / pi`` (i.e. treated as radians). The first is used as pitch,
      the second as yaw.

    Args:
        path: A single label-file path, or a list/tuple of label-file paths.
            Only the single-path form filters samples by angle.
        root: Directory the face image paths are relative to.
        transform: Optional callable applied to the opened PIL image.
        angle: Half-width, in degrees, of the range covered by the bins; also
            the filter threshold when ``train`` is true.
        binwidth: Width of one bin in degrees.
        train: When false, the filter threshold becomes 90 degrees while the
            bins still span ``angle``.
    """

    def __init__(self, path, root, transform, angle, binwidth, train=True):
        self.transform = transform
        self.root = root
        self.orig_list_len = 0
        self.angle = angle
        # The local `angle` is only the filter threshold from here on;
        # self.angle (used for binning) keeps the value that was passed in.
        if not train:
            angle = 90
        self.binwidth = binwidth
        self.lines = []

        if isinstance(path, (list, tuple)):
            for label_file in path:
                entries = self._read_entries(label_file)
                # Count these too so the summary below does not go negative.
                self.orig_list_len += len(entries)
                self.lines.extend(entries)
        else:
            entries = self._read_entries(path)
            self.orig_list_len = len(entries)
            for entry in entries:
                degrees = self._parse_gaze(entry) * 180 / np.pi
                if abs(degrees[0]) <= angle and abs(degrees[1]) <= angle:
                    self.lines.append(entry)

        print("{} items removed from dataset that have an angle > {}".format(self.orig_list_len-len(self.lines), angle))

    @staticmethod
    def _read_entries(label_file):
        """Return the lines of ``label_file`` with the header line dropped."""
        with open(label_file) as f:
            # Slicing (rather than pop(0)) tolerates an empty file.
            return f.readlines()[1:]

    @staticmethod
    def _parse_gaze(entry):
        """Return field 5 of a label line as a float array (values as stored)."""
        gaze2d = entry.strip().split(" ")[5]
        return np.array(gaze2d.split(",")).astype("float")

    def __len__(self):
        return len(self.lines)

    def __getitem__(self, idx):
        """Return ``(img, labels, cont_labels, name)`` for sample ``idx``.

        ``labels`` is a NumPy integer array ``[pitch_bin, yaw_bin]``,
        ``cont_labels`` a float tensor ``[pitch, yaw]`` in degrees.
        """
        entry = self.lines[idx]
        fields = entry.strip().split(" ")
        face = fields[0]
        name = fields[3]

        label = torch.from_numpy(self._parse_gaze(entry)).type(torch.FloatTensor)
        pitch = label[0] * 180 / np.pi
        yaw = label[1] * 180 / np.pi

        img = Image.open(os.path.join(self.root, face))
        if self.transform:
            img = self.transform(img)

        # Bin edges start at -angle; digitize is 1-based, hence the -1.
        bins = np.arange(-self.angle, self.angle, self.binwidth)
        labels = np.digitize([pitch, yaw], bins) - 1
        cont_labels = torch.FloatTensor([pitch, yaw])

        return img, labels, cont_labels, name

In [None]:
import os
import argparse
import time
import datetime
import pytz

import torch.utils.model_zoo as model_zoo
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
import torch.backends.cudnn as cudnn
import torchvision
from tqdm import tqdm




def train_Gaze360(config):
    """Fine-tune ``L2CS_MobileNetV3`` (large) on the Gaze360 training labels.

    Each head is trained with cross-entropy on the binned angle plus
    ``alpha`` times the MSE between the continuous label and the angle
    decoded from the softmax expectation over bins. After every epoch a
    checkpoint is written under a timestamped sub-directory of the snapshot
    folder, and the mean epoch losses and the checkpoint are sent to wandb.

    Relies on names defined elsewhere in the notebook: ``L2CS_MobileNetV3``,
    ``Gaze360``, the ``get_*_params`` generators and an initialised ``wandb``
    run. Dataset and output locations are hard-coded below, and a CUDA device
    is required.

    Args:
        config: Mapping with keys ``'epochs'``, ``'learning_rate'``,
            ``'batch_size'``, ``'architecture'`` (only printed) and
            ``'snapshot'`` (checkpoint path to resume from, or ``''``).
    """
    cudnn.enabled = True
    num_epochs = config["epochs"]
    lr = config['learning_rate']
    batch_size = config['batch_size']
    gpu = torch.device("cuda")
    data_set = 'gaze360'
    alpha = 1  # weight of the MSE term relative to the cross-entropy term
    output = '/content/drive/MyDrive/L2CSMobile/outputs/snapshots/'
    arch = config['architecture']
    snapshot = config['snapshot']
    # NOTE(review): the directory is 'Gaze361' while the archive extracted
    # elsewhere in this notebook is 'Gaze360.zip' - confirm the folder name.
    gaze360label_dir = '/content/Gaze361/Label/train.label'
    gaze360image_dir = '/content/Gaze361/Image'

    # Binning shared by the model heads, the dataset labels and the
    # expectation decoding in the loop: 90 bins * 4 degrees = [-180, 180).
    num_bins = 90
    bin_width = 4
    angle_range = 180

    transformations = transforms.Compose([
        transforms.Resize(448),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])

    if data_set != "gaze360":
        return

    model = L2CS_MobileNetV3(num_bins, version='large')
    model.to(gpu)

    # First feature block gets lr 0; the rest of the backbone and both heads
    # are trained at `lr`.
    optimizer_gaze = torch.optim.Adam([
        {'params': get_ignored_params(model), 'lr': 0},
        {'params': get_non_ignored_params(model), 'lr': lr},
        {'params': get_fc_params(model), 'lr': lr}
    ], lr)

    start_epoch = 0
    if snapshot != '':
        checkpoint = torch.load(snapshot)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer_gaze.load_state_dict(checkpoint['optimizer_state_dict'])
        # Checkpoints store the index of the epoch that just finished, so
        # resume with the next one instead of repeating it.
        start_epoch = checkpoint['epoch'] + 1

    dataset = Gaze360(gaze360label_dir, gaze360image_dir, transformations,
                      angle_range, bin_width)
    print('Loading data.')
    train_loader_gaze = DataLoader(
        dataset=dataset,
        batch_size=int(batch_size),
        shuffle=True,
        num_workers=0,
        pin_memory=True)
    torch.backends.cudnn.benchmark = True

    timezone = pytz.timezone('Asia/Seoul')
    seoul_time = datetime.datetime.now(timezone)
    summary_name = '{}_{}'.format('L2CSM-gaze360-', seoul_time.strftime('%Y-%m-%d_%H-%M-%S'))
    output = os.path.join(output, summary_name)
    os.makedirs(output, exist_ok=True)

    criterion = nn.CrossEntropyLoss().to(gpu)
    reg_criterion = nn.MSELoss().to(gpu)
    softmax = nn.Softmax(dim=1).to(gpu)
    # Bin indices 0..num_bins-1, used to take the expectation over bins.
    idx_tensor = torch.arange(num_bins, dtype=torch.float32, device=gpu)

    configuration = f"\ntrain configuration, batch_size={batch_size}, model_arch={arch}\nStart testing dataset={data_set}, loader={len(train_loader_gaze)}------------------------- \n"
    print(configuration)

    for epoch in range(start_epoch, num_epochs):
        # Plain Python floats: accumulating the loss tensors themselves would
        # keep every iteration's autograd history alive for the whole epoch.
        sum_loss_pitch_gaze = 0.0
        sum_loss_yaw_gaze = 0.0
        iter_gaze = 0

        train_loader_gaze_tqdm = tqdm(train_loader_gaze, desc=f"Epoch {epoch+1}/{num_epochs}")

        for images_gaze, labels_gaze, cont_labels_gaze, _name in train_loader_gaze_tqdm:
            images_gaze = images_gaze.to(gpu)

            # Binned labels
            label_pitch_gaze = labels_gaze[:, 0].to(gpu)
            label_yaw_gaze = labels_gaze[:, 1].to(gpu)

            # Continuous labels
            label_pitch_cont_gaze = cont_labels_gaze[:, 0].to(gpu)
            label_yaw_cont_gaze = cont_labels_gaze[:, 1].to(gpu)

            pitch, yaw = model(images_gaze)

            # Cross entropy loss
            loss_pitch_gaze = criterion(pitch, label_pitch_gaze)
            loss_yaw_gaze = criterion(yaw, label_yaw_gaze)

            # Expected bin index -> angle in degrees
            pitch_predicted = (
                torch.sum(softmax(pitch) * idx_tensor, 1) * bin_width - angle_range)
            yaw_predicted = (
                torch.sum(softmax(yaw) * idx_tensor, 1) * bin_width - angle_range)

            # MSE loss
            loss_reg_pitch = reg_criterion(pitch_predicted, label_pitch_cont_gaze)
            loss_reg_yaw = reg_criterion(yaw_predicted, label_yaw_cont_gaze)

            # Total loss
            loss_pitch_gaze += alpha * loss_reg_pitch
            loss_yaw_gaze += alpha * loss_reg_yaw

            loss_seq = [loss_pitch_gaze, loss_yaw_gaze]
            grad_seq = [torch.ones_like(loss) for loss in loss_seq]
            optimizer_gaze.zero_grad(set_to_none=True)
            torch.autograd.backward(loss_seq, grad_seq)
            optimizer_gaze.step()

            sum_loss_pitch_gaze += loss_pitch_gaze.item()
            sum_loss_yaw_gaze += loss_yaw_gaze.item()
            iter_gaze += 1

            train_loader_gaze_tqdm.set_postfix({
                "Loss Pitch": sum_loss_pitch_gaze / iter_gaze,
                "Loss Yaw": sum_loss_yaw_gaze / iter_gaze
            })

        checkpoint_path = output + '/checkpoint_epoch_{}.pth'.format(epoch)
        print('Taking snapshot...')
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer_gaze.state_dict(),
        }, checkpoint_path)
        wandb.log({
            "loss_pitch_gaze": sum_loss_pitch_gaze / iter_gaze,
            "loss_yaw_gaze": sum_loss_yaw_gaze / iter_gaze,
        }, step=epoch + 1)
        wandb.save(checkpoint_path)



In [None]:
import os
from google.colab import drive

# Mount Google Drive so the dataset archive stored there can be read.
drive.mount('/content/drive')
# Extract the archive into /content/; -u only writes files that are missing
# or older than the archived copy, so re-running the cell is cheap.
!unzip -u /content/drive/MyDrive/L2CS-Net/Gaze360.zip -d /content/

In [None]:
import wandb
# Log in to Weights & Biases before any run is started in this notebook.
wandb.login()

In [None]:
def run():
    """Start a wandb run, train with the settings below, then close the run.

    The settings dict is passed both to ``wandb.init`` (so the run records
    its hyper-parameters) and to ``train_Gaze360``. ``wandb.finish()`` is
    called even when training raises, so a failed run is not left open.

    To resume an earlier run, set ``'snapshot'`` to a checkpoint path and
    additionally pass ``id=<run id>, resume="must"`` to ``wandb.init``.
    """
    config = {
        'epochs': 50,
        'batch_size': 16,
        'learning_rate': 1e-5,
        'dataset': 'Gaze360',
        'architecture': 'MobileNetV3',
        'optimizer': 'Adam',
        'snapshot': '',
    }
    wandb.init(project='Itda-L2CS', name='MobilenetV3L', config=config)
    try:
        train_Gaze360(config)
    finally:
        wandb.finish()

In [None]:
# Launch training with the settings defined in run().
run()