In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import warnings
# Install a blanket "ignore" filter: no category is given, so every warning
# raised after this point in the session is suppressed.
warnings.filterwarnings("ignore")

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# for dirname, _, filenames in os.walk('/kaggle/input'):
#     for filename in filenames:
# #         print(os.path.join(dirname, filename))
#         pass

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
import torch
import torch.nn as nn

In [None]:
# Display whether PyTorch can use a CUDA device in this session.
torch.cuda.is_available()

In [None]:
from transformers import ResNetModel,ResNetConfig,MobileNetV2Config, MobileNetV2Model,MobileNetV1Config, MobileNetV1Model

In [None]:
def collect_pose_paths(split_root):
    """Map every directory under ``split_root`` to the files it directly contains.

    Each key is the lower-cased last ``/``-separated component of a directory
    visited by ``os.walk``; each value is the list of full file paths found in
    that directory. When ``split_root`` ends with ``/`` the root itself is
    recorded under the empty-string key ``''`` (a later cell removes it).

    Sub-directories and files are visited in sorted order so that the class
    order, and therefore the label indices derived from it, does not depend
    on the arbitrary order in which the filesystem lists entries.

    Args:
        split_root: directory of one dataset split (train / test / valid).

    Returns:
        dict mapping lower-cased directory name -> list of file paths.
        Empty when ``split_root`` does not exist.
    """
    paths_by_pose = {}
    for dirname, subdirs, filenames in os.walk(split_root):
        # Sorting in place steers os.walk's top-down traversal order.
        subdirs.sort(key=str.lower)
        pose_name = dirname.split('/')[-1].lower()
        paths_by_pose[pose_name] = [
            os.path.join(dirname, filename) for filename in sorted(filenames)
        ]
    return paths_by_pose


poses_to_path_train = collect_pose_paths('/kaggle/input/yoga-82/train/')
poses_to_path_test = collect_pose_paths('/kaggle/input/yoga-82/test/')
poses_to_path_val = collect_pose_paths('/kaggle/input/yoga-82/valid/')

# Class names come from the training split, in traversal order; the split
# root (empty name '') is therefore the first entry.
poses = list(poses_to_path_train)


In [None]:
# Remove the entry recorded for each split's root directory (empty name '').
# Written so the cell can be re-run safely: filtering on the empty name and
# pop(..., None) do nothing the second time, whereas slicing poses[1:] would
# discard a real pose and `del` would raise KeyError.
poses = [pose for pose in poses if pose]
for paths_by_pose in (poses_to_path_train, poses_to_path_test, poses_to_path_val):
    paths_by_pose.pop('', None)

In [None]:
# Integer label for every pose name, following the order of `poses`.
poses_to_idx = {pose_name: label for label, pose_name in enumerate(poses)}

In [None]:
from torch.utils.data import Dataset

In [None]:
from PIL import Image
from PIL import ImageFile
# Ask PIL to load image files whose data is truncated rather than erroring out.
ImageFile.LOAD_TRUNCATED_IMAGES = True
class YogaPoseDataset(Dataset):
    """Class-balanced image dataset over a ``{pose name: [image paths]}`` mapping.

    Consecutive indices cycle through the classes (``idx % class_num``), and
    within a class the image index wraps around, so classes with fewer images
    are repeated until the largest class has been covered once.

    Args:
        poses_to_path: dict mapping pose name -> list of image file paths.
        poses_to_idx: dict mapping pose name -> integer label; only its size
            is used here (number of classes).
        poses: sequence of pose names; position ``i`` is the class returned
            as label ``i``.
        transform: optional callable applied to the loaded RGB PIL image.
    """
    def __init__(self,poses_to_path,poses_to_idx,poses,transform=None):
        self.poses_to_path = poses_to_path
        # Size of the largest class. Iterate the path lists (.values()):
        # iterating the dict itself yields the pose-name strings, whose
        # lengths say nothing about how many images a class has.
        # default=0 keeps an empty mapping from raising in max().
        self.max_size = max((len(paths) for paths in poses_to_path.values()), default=0)
        self.poses_to_idx = poses_to_idx
        self.class_num = len(self.poses_to_idx)
        self.poses=poses
        self.transform=transform
    def __len__(self):
        """Return largest class size times number of classes."""
        return self.max_size*self.class_num
    def __getitem__(self,idx):
        """Return ``(image, class index)`` for the flat index ``idx``."""
        pose_class = idx%self.class_num
        class_paths = self.poses_to_path[self.poses[pose_class]]
        # Wrap around so smaller classes are revisited (oversampling).
        img_id = (idx//self.class_num)%len(class_paths)
        img = Image.open(class_paths[img_id]).convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img,pose_class
        

In [None]:
from torchvision import transforms as T

In [None]:
IMAGE_SIZE = (128, 128)


def _base_steps():
    """Return fresh resize -> tensor -> normalise steps shared by both pipelines."""
    return [
        T.Resize(IMAGE_SIZE, interpolation=T.InterpolationMode.BILINEAR),
        T.ToTensor(),
        T.Normalize([0.5], [0.5]),
    ]


# Training pipeline: the shared steps followed by the random augmentations,
# applied in exactly this order on the normalised tensor.
# NOTE(review): ColorJitter() is built with no arguments - confirm whether
# jitter strengths were meant to be set.
transform = T.Compose(
    _base_steps()
    + [
        T.RandomErasing(p=1.0, value="random"),
        T.ColorJitter(),
        T.RandomRotation((-45, 45)),
        T.RandomHorizontalFlip(p=0.5),
    ]
)

# Validation pipeline: the shared steps only, no augmentation.
val_trans = T.Compose(_base_steps())

In [None]:
# The training split gets the augmenting pipeline, the validation split the plain one.
train_ds = YogaPoseDataset(
    poses_to_path=poses_to_path_train,
    poses_to_idx=poses_to_idx,
    poses=poses,
    transform=transform,
)
val_ds = YogaPoseDataset(
    poses_to_path=poses_to_path_val,
    poses_to_idx=poses_to_idx,
    poses=poses,
    transform=val_trans,
)


In [None]:
from torch.utils.data import random_split

In [None]:
# train_ds,val_ds = random_split(ds,[0.8,0.2])

In [None]:
# Backbone: MobileNetV2 constructed from its default config (no from_pretrained call).
# Alternative backbones left commented out by the author:
#   model = ResNetModel(ResNetConfig(depths=[2,2,2,2]))
#   model = MobileNetV1Model(MobileNetV1Config())
model = MobileNetV2Model(MobileNetV2Config())

FEATURE_DIM = 1280   # input width of the head (flattened pooled backbone output)
HIDDEN_DIM = 256     # width of the intermediate layer
HEAD_DROPOUT = 0.7   # dropout probability used in both halves of the head

# Classification head: two BatchNorm/Dropout/ReLU/Linear stages ending in one
# logit per pose class. Earlier variants (a 1024-wide input stage, and a head
# sized from model.config.hidden_sizes[-1]) were commented out in the original.
mlp = nn.Sequential(
    nn.BatchNorm1d(FEATURE_DIM),
    nn.Dropout(HEAD_DROPOUT),
    nn.ReLU(),
    nn.Linear(FEATURE_DIM, HIDDEN_DIM),
    nn.Dropout(HEAD_DROPOUT),
    nn.BatchNorm1d(HIDDEN_DIM),
    nn.ReLU(),
    nn.Linear(HIDDEN_DIM, train_ds.class_num),
)

In [None]:
from torch.utils.data import DataLoader 

In [None]:
BATCH_SIZE = 64
NUM_WORKERS = 4

# Training batches are reshuffled on every pass over the loader.
train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
# Validation keeps a fixed order: the reported validation loss is a mean of
# per-batch means, so reshuffling would change which samples land in the
# (possibly smaller) last batch and make the metric vary for identical weights.
val_dl = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

In [None]:
import torch.optim

In [None]:
# A single AdamW instance updates the backbone and the head together.
trainable_params = list(model.parameters()) + list(mlp.parameters())
opt = torch.optim.AdamW(trainable_params, lr=5e-4, weight_decay=5e-4)

In [None]:
# Use the GPU when one is attached; otherwise fall back to the CPU instead of
# failing on .to("cuda").
device = "cuda" if torch.cuda.is_available() else "cpu"

model.to(device)
mlp.to(device)

NUM_EPOCHS = 500

for epoch in range(NUM_EPOCHS):
    # ---- training pass ----
    model.train()
    mlp.train()
    tr_loss = []
    res = []
    for imgs, labels in train_dl:
        # Move the labels once per batch; they are needed for loss and accuracy.
        labels = labels.to(device)
        encodings = model(imgs.to(device))
        logits = mlp(encodings.pooler_output.flatten(1))
        loss = nn.functional.cross_entropy(logits, labels)
        tr_loss.append(loss.detach().cpu())
        res.append(torch.argmax(logits, dim=1).detach() == labels)
        loss.backward()
        opt.step()
        opt.zero_grad()
    # Mean of the per-batch losses and fraction of correct predictions.
    tr_loss = torch.stack(tr_loss).float().mean()
    tr_acc = torch.cat(res).float().mean()

    # ---- validation pass (no gradients, eval-mode BatchNorm/Dropout) ----
    model.eval()
    mlp.eval()
    res = []
    val_loss = []
    with torch.no_grad():
        for imgs, labels in val_dl:
            labels = labels.to(device)
            encodings = model(imgs.to(device))
            logits = mlp(encodings.pooler_output.flatten(1))
            loss = nn.functional.cross_entropy(logits, labels)
            res.append(torch.argmax(logits, dim=1) == labels)
            val_loss.append(loss.detach())
    print(f"{epoch} Tr loss: {tr_loss} Tr acc: {tr_acc}","Val loss: ",torch.stack(val_loss).float().mean(),"Val acc: ",torch.cat(res).float().mean())
        