In [1]:
import os
import gc
import cv2
import time
import random
import glob
from PIL import Image
import  matplotlib.pyplot as plt

# For data manipulation
import numpy as np
import pandas as pd

# Pytorch Imports
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader
from torch.cuda import amp
from pytorch_toolbelt import losses as L

# Utils
from tqdm.auto import tqdm

# For Image Models
import timm

# Albumentations for augmentations
import albumentations as A
from albumentations.pytorch import ToTensorV2

## using gpu:1
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

def seed_everything(seed=123):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
seed_everything()

In [2]:
class Customize_Model(nn.Module):
    def __init__(self, model_name, cls):
        super().__init__()
        
    def forward(self, image):
        x = self.model(image)
        return x

In [3]:
def get_test_transform(img_size):
    return A.Compose([
        A.Resize(img_size, img_size),
        ToTensorV2(p=1.0),
    ])


def read_video(path, 
               hitframe,
               video_load_frac=1.0):
    imgs= []
    cap= cv2.VideoCapture(path)
    while cap.isOpened():
        ret, img = cap.read()
        if not ret: break
        img= cv2.resize(img, None, fx=video_load_frac, fy=video_load_frac)
        img= cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        imgs.append(img)
        
    ## make label
    label = [0]*len(imgs)
    for i in hitframe: 
        label[i-3]= 1
        label[i-2]= 1
        label[i-1]= 1
        label[i]= 1
        label[i+1]= 1
    
    return np.array(imgs), np.array(label) ## (img_len, H, W)

def read_label(path):
    df= pd.read_csv(path)
    return df['HitFrame'].values

class Customize_Dataset(Dataset):
    def __init__(self, df, transforms=None, is_train=True):
        self.df = df
        self.transforms = transforms
        self.is_train= is_train
    
    def __getitem__(self, index):
        data = self.df.loc[index]
        hitframe= read_label(data['label'])
        img,label = read_video(data['image_path'],
                               hitframe)
        img = img.transpose(1,2,0)
        
        if self.transforms:
            img = self.transforms(image=img)["image"]
        
        ## convert to 3 channel
        img= img.unsqueeze(dim=0)
        img= img.expand(3,img.shape[1],img.shape[2],img.shape[3])
            
        return {
            'image': torch.tensor(img/255, dtype=torch.float32),
            'label': torch.tensor(label, dtype=torch.float32),
        }
    
    def __len__(self):
        return len(self.df)

# CFG

In [4]:
CFG= {
    'fold': 0,
    
    'img_size': 640,
    'frame_length': 32,
    
    'TTA': 1,  ## disable TTA= 1
    'model': [
        './test_model/hitframe/csn_s640_d32/cv0_best.pth',
    ]
}
CFG['sample_rate']= int(CFG['frame_length']/2)
CFG['model']= [ torch.load(m, map_location= 'cuda:0') for m in CFG['model'] ]
print(f"length of model: {len(CFG['model'])}")

length of model: 1


# Prepare Dataset

In [5]:
df= pd.read_csv('data/train.csv')
valid_df= df[df['fold']==CFG['fold']].reset_index(drop=True)
print(f'valid dataset: {len(valid_df)}')

valid_dataset= Customize_Dataset(valid_df.iloc[:].reset_index(drop=True), get_test_transform(CFG['img_size']))
valid_loader= DataLoader(valid_dataset, batch_size=1, shuffle=False, num_workers=0)
valid_df.head()

valid dataset: 148


Unnamed: 0,image_path,label,fold
0,Data/羽球AICUP_001/part1/train\00003\00003.mp4,Data/羽球AICUP_001/part1/train\00003\00003_S2.csv,0.0
1,Data/羽球AICUP_001/part1/train\00005\00005.mp4,Data/羽球AICUP_001/part1/train\00005\00005_S2.csv,0.0
2,Data/羽球AICUP_001/part1/train\00008\00008.mp4,Data/羽球AICUP_001/part1/train\00008\00008_S2.csv,0.0
3,Data/羽球AICUP_001/part1/train\00010\00010.mp4,Data/羽球AICUP_001/part1/train\00010\00010_S2.csv,0.0
4,Data/羽球AICUP_001/part1/train\00014\00014.mp4,Data/羽球AICUP_001/part1/train\00014\00014_S2.csv,0.0


In [6]:
def inference(model, img):
    
    img= torch.unsqueeze(img, 0).cuda()
    for i, m in enumerate(model):
        with torch.no_grad():
            m.eval()
            imgs= torch.cat([img, img.flip(-1), img.flip(-2), img.flip(-1).flip(-2)], dim=0)
            pred= m(imgs[:CFG['TTA']])
            pred= pred.mean(dim=0)
                
        if i==0: preds= pred.sigmoid()
        else: preds+= pred.sigmoid()
            
    pred= preds/len(model)
    pred= pred.cpu().numpy().tolist()
    return pred

In [7]:
valid_df['pred_prob']= None
count= 0
for i, data in enumerate(tqdm(valid_loader)):
    for j in range(len(data['image'])):
        imgs= data['image'][j]
        label= data['label'][j].numpy().tolist()
        valid_df.at[count, 'label']= label
        
        total_pred= []
        indx=0
        for f in range(imgs.shape[1]):
            img= imgs[:,indx:indx+CFG['frame_length']]
            if img.shape[1]<CFG['frame_length']: break
            pred= inference(CFG['model'], img)
            if f==0: total_pred+= pred
            else:
                total_pred= np.array(total_pred)
                total_pred[-CFG['sample_rate']:]+= pred[:CFG['sample_rate']]
                total_pred= np.append(total_pred, pred[-CFG['sample_rate']:])
            indx+= CFG['sample_rate']
            
        ## overlap half area
        total_pred[ CFG['sample_rate']:-CFG['sample_rate'] ]/= 2
        total_pred= total_pred.tolist()
            
        ## fill total_pred length to label
        total_pred+= [0]* (len(label)-len(total_pred))
        valid_df.at[count, 'pred_prob']= total_pred
        count+= 1
valid_df.head()

  0%|          | 0/148 [00:00<?, ?it/s]

  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image':

  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image':

  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image':

  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image':

  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image':

  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),
  'image': torch.tensor(img/255, dtype=torch.float32),


Unnamed: 0,image_path,label,fold,pred_prob
0,Data/羽球AICUP_001/part1/train\00003\00003.mp4,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",0.0,"[0.056670889258384705, 0.02669801004230976, 0...."
1,Data/羽球AICUP_001/part1/train\00005\00005.mp4,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",0.0,"[0.1635291874408722, 0.08632658421993256, 0.06..."
2,Data/羽球AICUP_001/part1/train\00008\00008.mp4,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",0.0,"[0.26486101746559143, 0.2446281909942627, 0.23..."
3,Data/羽球AICUP_001/part1/train\00010\00010.mp4,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",0.0,"[0.2416536509990692, 0.20668581128120422, 0.18..."
4,Data/羽球AICUP_001/part1/train\00014\00014.mp4,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",0.0,"[0.12637192010879517, 0.062402356415987015, 0...."


In [8]:
from metrics import *

aucs= []
temp_df= valid_df[valid_df['pred_prob']!=None]
all_pred= temp_df['pred_prob'].values
all_label= temp_df['label'].values
for i in range(len(temp_df)):
    p= [np.array(all_pred[i])]
    la= [np.array(all_label[i])]
    score= AUC(p, la)
    aucs.append(score)
print(np.mean(aucs))

0.9638379434264727
