In [1]:
import gc
import random
import pandas as pd
import numpy as np
import os
import cv2

import torch
import torch.nn as nn
from torch import Tensor
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
import torchvision.models as models

from tqdm.auto import tqdm
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score

import warnings
warnings.filterwarnings(action='ignore') 

In [2]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
# Notebook-wide settings: clip geometry, training hyper-parameters and the seed.
CFG = dict(
    VIDEO_LENGTH=50,     # frames read per clip: 10 frames * 5 seconds
    IMG_SIZE_H=208,      # slow : 208, 256
    IMG_SIZE_W=416,
    EPOCHS=80,
    LEARNING_RATE=1e-6,
    BATCH_SIZE=16,
    SEED=909,
)

In [4]:
def seed_everything(seed):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random`` module, NumPy and PyTorch (CPU and every CUDA
    device), exports ``PYTHONHASHSEED`` and configures cuDNN for
    deterministic behaviour.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seeds all visible GPUs, not only the current one; a no-op without CUDA.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-tune and pick a (possibly different)
    # algorithm on each run, which undoes the determinism requested above.
    torch.backends.cudnn.benchmark = False

seed_everything(CFG['SEED']) # Fix the seed

In [5]:
# Row labels removed from the metadata frame. The second half of the list is
# the first half shifted by 2698 (presumably the matching augmented copies --
# TODO confirm against the CSV layout).
error_data = [8, 124, 330, 387, 486, 1113, 1533, 2292]
error_data = error_data + [row + 2698 for row in error_data]

# Load the metadata, drop the listed rows in one call and renumber from 0.
df = (
    pd.read_csv('./datasets/train_CE-W-T_aug2.csv')
    .drop(index=error_data)
    .reset_index(drop=True)
)

In [6]:
class CustomDataset(Dataset):
    """Dataset that decodes a fixed number of frames from each video file.

    Args:
        video_path_list: Indexable collection of video file paths.
        label_list: Indexable collection of labels aligned with
            ``video_path_list``, or ``None`` for unlabeled data.
    """

    def __init__(self, video_path_list, label_list):
        self.video_path_list = video_path_list
        self.label_list = label_list

    def __getitem__(self, index):
        """Return ``(frames, label)``, or only ``frames`` when there are no labels."""
        frames = self.get_video(self.video_path_list[index])

        if self.label_list is not None:
            label = self.label_list[index]
            return frames, label
        else:
            return frames

    def __len__(self):
        """Number of videos in the dataset."""
        return len(self.video_path_list)

    def get_video(self, path):
        """Read the first ``CFG['VIDEO_LENGTH']`` frames of the video at ``path``.

        Every frame is resized to ``(CFG['IMG_SIZE_W'], CFG['IMG_SIZE_H'])`` and
        divided by 255. Frames keep the channel order delivered by OpenCV; no
        colour conversion is applied.

        Returns:
            Float tensor laid out as (channels, frames, height, width).

        Raises:
            RuntimeError: If the video cannot be opened or yields fewer than
                ``CFG['VIDEO_LENGTH']`` frames.
        """
        frames = []
        cap = cv2.VideoCapture(path)
        try:
            for frame_no in range(CFG['VIDEO_LENGTH']):
                ok, img = cap.read()
                # read() signals failure through its flag; without this check a
                # short or unreadable file crashes inside cv2.resize with an
                # error that does not mention which video was at fault.
                if not ok or img is None:
                    raise RuntimeError(
                        f"Could not read frame {frame_no} of {CFG['VIDEO_LENGTH']} from video: {path}"
                    )
                img = cv2.resize(img, (CFG['IMG_SIZE_W'], CFG['IMG_SIZE_H']))
                img = img / 255.
                frames.append(img)
        finally:
            # Always free the decoder handle, including on the error path.
            cap.release()
        return torch.FloatTensor(np.array(frames)).permute(3, 0, 1, 2)

In [7]:
class BaseModel2(nn.Module):
    """Pretrained X3D-S video backbone with a replacement classification head.

    Args:
        num_classes: Number of output logits produced by the new head.
    """

    def __init__(self, num_classes=None):
        super().__init__()
        # Alternative backbone noted in an earlier revision: 'slow_r50' from the same hub repo.
        backbone = torch.hub.load('facebookresearch/pytorchvideo', 'x3d_s', pretrained=True)
        self.base_model = backbone.to(device)

        # Replace the final projection (2048 input features) with a small MLP head.
        head = nn.Sequential(
            nn.Linear(2048, 400),
            nn.ReLU(),
            nn.Dropout(p=0.2),
            nn.Linear(400, num_classes),
        )
        self.base_model.blocks[5].proj = head.to(device)

    def forward(self, x):
        """Run the backbone, including the new head, on ``x`` and return its output."""
        return self.base_model(x)

In [8]:
# Columns that identify a sample; each task frame adds its own label column.
key_cols = ['sample_id', 'video_path']

df_crash_ego = df[key_cols + ['crash_ego']]

# Weather: rows labeled -1 are unlabeled, the rest form the labeled set.
weather_all = df[key_cols + ['weather']]
weather_missing = weather_all['weather'].eq(-1)
df_unlabled_weather = weather_all[weather_missing].reset_index(drop=True)
df_weather = weather_all[weather_all['weather'].gt(-1)]

# Timing: same split on the timing label.
timing_all = df[key_cols + ['timing']]
timing_missing = timing_all['timing'].eq(-1)
df_unlabled_timing = timing_all[timing_missing].reset_index(drop=True)
df_timing = timing_all[timing_all['timing'].gt(-1)]

# Labeled counts per task, then unlabeled counts for weather / timing.
print(len(df_crash_ego), len(df_weather), len(df_timing))
print(len(df_unlabled_weather), len(df_unlabled_timing))

5380 2132 2156
3248 3224


In [9]:
# Class distribution of the labeled weather samples.
df_weather['weather'].value_counts()

0    1656
1     297
2     179
Name: weather, dtype: int64

In [10]:
# Class distribution of the labeled timing samples.
df_timing['timing'].value_counts()

0    1801
1     355
Name: timing, dtype: int64

In [12]:
## Weather
# Pseudo-label the unlabeled weather clips with a trained model and save the
# labeled set extended by the confident predictions.

# Load trained weights
# NOTE(review): torch.load unpickles the file -- only load checkpoints you trust.
# map_location lets a checkpoint saved on GPU be restored on a CPU-only machine.
weather_model_weights = torch.load(
    './models/x3d_s-0.6735-208_416/CE-W-T_[x3d_s]_[weather]_[score0.8889]_[loss0.0323].pt',
    map_location=device,
)
n_weather_classes = 3
weather_model = BaseModel2(num_classes=n_weather_classes)
weather_model.load_state_dict(weather_model_weights)

unlabeld_dataset = CustomDataset(df_unlabled_weather['video_path'].values, None)
unlabeld_loader = DataLoader(unlabeld_dataset, batch_size = CFG['BATCH_SIZE'], shuffle=False, num_workers=8)

weather_model.to(device)
weather_model.eval()

# A prediction is kept only when its top softmax probability exceeds this value.
prob = 0.9

# Inference: collect the per-class softmax probabilities of every unlabeled clip.
preds = []
with torch.no_grad():
    for videos in tqdm(unlabeld_loader):
        videos = videos.to(device)

        logit = weather_model(videos)

        logit = torch.softmax(logit, dim=1)
        preds += logit.cpu().numpy().tolist()

# check proba & idx
# reshape keeps the array 2-D even when there are no predictions at all.
probs = np.asarray(preds, dtype=float).reshape(-1, n_weather_classes)
pred_label = probs.argmax(axis=1)  # computed once instead of once per row
max_proba = probs.max(axis=1) # top probability per clip
pseudo_idx = np.where(max_proba > prob, 1, 0) # 1 where the top probability exceeds `prob`

# The positional masks below rely on predictions and rows lining up one-to-one
# (the loader is not shuffled and the frame index was reset to 0..n-1).
assert len(probs) == len(df_unlabled_weather), 'prediction count does not match unlabeled rows'

# Make pseudo dataframe
# Only confident predictions of a non-zero class are added; class 0 is skipped
# (presumably because it already dominates the labeled set -- confirm intent).
keep = (pseudo_idx == 1) & (pred_label != 0)
pseudo_rows = (
    df_unlabled_weather
    .loc[keep, ['sample_id', 'video_path']]
    .assign(weather=pred_label[keep])
)
# DataFrame.append was removed in pandas 2.0 and copied the frame on every
# call; build the new rows once and concatenate a single time instead.
df_weather_pseudo = pd.concat([df_weather, pseudo_rows], ignore_index=True)
added = pseudo_rows[['sample_id', 'weather']].values.tolist()

df_weather_pseudo.to_csv(f'./datasets/train_weather_pseudo_{prob}.csv', index=False)
print(len(added))
df_weather_pseudo['weather'].value_counts()

Using cache found in /root/.cache/torch/hub/facebookresearch_pytorchvideo_main


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=203.0), HTML(value='')))


200


0    1656
1     359
2     317
Name: weather, dtype: int64

In [14]:
# Inspect the last rows of the extended weather set (the appended pseudo-labels).
df_weather_pseudo.tail(50)

Unnamed: 0,sample_id,video_path,weather
2282,TRAIN_4176,./datasets/train_aug/TRAIN_4176.mp4,2
2283,TRAIN_4182,./datasets/train_aug/TRAIN_4182.mp4,2
2284,TRAIN_4255,./datasets/train_aug/TRAIN_4255.mp4,1
2285,TRAIN_4274,./datasets/train_aug/TRAIN_4274.mp4,2
2286,TRAIN_4289,./datasets/train_aug/TRAIN_4289.mp4,2
2287,TRAIN_4292,./datasets/train_aug/TRAIN_4292.mp4,2
2288,TRAIN_4297,./datasets/train_aug/TRAIN_4297.mp4,2
2289,TRAIN_4327,./datasets/train_aug/TRAIN_4327.mp4,1
2290,TRAIN_4351,./datasets/train_aug/TRAIN_4351.mp4,1
2291,TRAIN_4368,./datasets/train_aug/TRAIN_4368.mp4,1


In [13]:
## Timing
# Pseudo-label the unlabeled timing clips with a trained model and save the
# labeled set extended by the confident predictions.

# Load trained weights
# NOTE(review): torch.load unpickles the file -- only load checkpoints you trust.
# map_location lets a checkpoint saved on GPU be restored on a CPU-only machine.
timing_model_weights = torch.load(
    './models/x3d_s-0.6735-208_416/CE-W-T_[x3d_s]_[timing]_[score0.9730]_[loss0.0252].pt',
    map_location=device,
)
n_timing_classes = 2
timing_model = BaseModel2(num_classes=n_timing_classes)
timing_model.load_state_dict(timing_model_weights)

unlabeld_dataset = CustomDataset(df_unlabled_timing['video_path'].values, None)
unlabeld_loader = DataLoader(unlabeld_dataset, batch_size = CFG['BATCH_SIZE'], shuffle=False, num_workers=8)

timing_model.to(device)
timing_model.eval()

# A prediction is kept only when its top softmax probability exceeds this value.
prob = 0.8

# Inference: collect the per-class softmax probabilities of every unlabeled clip.
preds = []
with torch.no_grad():
    for videos in tqdm(unlabeld_loader):
        videos = videos.to(device)

        logit = timing_model(videos)

        logit = torch.softmax(logit, dim=1)
        preds += logit.cpu().numpy().tolist()

# check proba & idx
# reshape keeps the array 2-D even when there are no predictions at all.
probs = np.asarray(preds, dtype=float).reshape(-1, n_timing_classes)
pred_label = probs.argmax(axis=1)  # computed once instead of once per row
max_proba = probs.max(axis=1) # top probability per clip
pseudo_idx = np.where(max_proba > prob, 1, 0) # 1 where the top probability exceeds `prob`

# The positional masks below rely on predictions and rows lining up one-to-one
# (the loader is not shuffled and the frame index was reset to 0..n-1).
assert len(probs) == len(df_unlabled_timing), 'prediction count does not match unlabeled rows'

# Make pseudo dataframe
# Only confident predictions of a non-zero class are added; class 0 is skipped
# (presumably because it already dominates the labeled set -- confirm intent).
keep = (pseudo_idx == 1) & (pred_label != 0)
pseudo_rows = (
    df_unlabled_timing
    .loc[keep, ['sample_id', 'video_path']]
    .assign(timing=pred_label[keep])
)
# DataFrame.append was removed in pandas 2.0 and copied the frame on every
# call; build the new rows once and concatenate a single time instead.
df_timing_pseudo = pd.concat([df_timing, pseudo_rows], ignore_index=True)
added = pseudo_rows[['sample_id', 'timing']].values.tolist()

df_timing_pseudo.to_csv(f'./datasets/train_timing_pseudo_{prob}.csv', index=False)
print(len(added))
df_timing_pseudo['timing'].value_counts()

Using cache found in /root/.cache/torch/hub/facebookresearch_pytorchvideo_main


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=202.0), HTML(value='')))


1066


0    1801
1    1421
Name: timing, dtype: int64