In [None]:
import os
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import torchsummary
from torchinfo import summary
from torchvision import transforms
from tqdm import tqdm

from pytorchtools import EarlyStopping
from video_dataset import VideoFrameDataset, ImglistToTensor

In [None]:
# Parameters
TEST_RATIO = 0.3          # fraction of annotations.txt entries held out by split_dataset()
MANUAL_SEED = 888         # seed for python `random`, numpy and torch (CPU + CUDA)
NUMBER_OF_SEGMENTS = 8    # passed to VideoFrameDataset as num_segments

# Single definition: a second assignment used to silently override an earlier value.
BATCH_SIZE = 32           # presumably consumed by a DataLoader in a later cell — TODO confirm
EPOCHS = 30               # presumably the training-loop length in a later cell — TODO confirm

In [None]:
# Seed every RNG source used by this notebook so the data split and training are repeatable.
manualSeed = MANUAL_SEED
print("Random Seed: ", manualSeed)

for _seed_fn in (
    random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,
    np.random.seed,
):
    _seed_fn(manualSeed)

# Run on the GPU whenever CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(f'Now running with device = {device}')

In [None]:
import random

def split_dataset(dst_path='/content/drive/MyDrive/Project/MLB/Data/Segment/all_test',
                  test_ratio=None):
    """Randomly split ``annotations.txt`` into train and test annotation files.

    Reads ``{dst_path}/annotations.txt`` and ignores blank lines. It draws
    ``int(len(entries) * test_ratio)`` entries without replacement using the
    global ``random`` module and writes them to ``annotations_test.txt``.
    Every other entry is written to ``annotations_train.txt``, keeping its
    original file order. Both output files are overwritten, and the split
    sizes are printed.

    Args:
        dst_path: Directory containing ``annotations.txt``. The split files are
            written to the same directory.
        test_ratio: Fraction of entries assigned to the test split, in [0, 1].
            ``None`` (the default) uses the notebook-level ``TEST_RATIO``.

    Raises:
        ValueError: If ``test_ratio`` is outside [0, 1].
    """
    if test_ratio is None:
        test_ratio = TEST_RATIO
    if not 0 <= test_ratio <= 1:
        raise ValueError(f'test_ratio must be in [0, 1], got {test_ratio}')

    with open(f'{dst_path}/annotations.txt', 'r') as f:
        # Blank lines (e.g. a trailing newline) would otherwise turn into empty
        # annotation records in the split files.
        entries = [line.strip() for line in f if line.strip()]

    test = random.sample(entries, int(len(entries) * test_ratio))
    # A set lookup keeps this linear instead of scanning `test` once per entry.
    # An entry drawn for test is excluded from train at every occurrence.
    test_lookup = set(test)
    train = [entry for entry in entries if entry not in test_lookup]

    print(f'Total dataset size = {len(entries)}')
    print('==================================')
    print(f'Train size = {len(train)}')
    print(f'Test  size = {len(test)}')

    _write_lines(f'{dst_path}/annotations_train.txt', train)
    _write_lines(f'{dst_path}/annotations_test.txt', test)


def _write_lines(path, lines):
    """Write each item of ``lines`` to ``path`` on its own line, overwriting the file."""
    with open(path, 'w') as f:
        for line in lines:
            f.write(f'{line}\n')


# Regenerate annotations_train.txt / annotations_test.txt from annotations.txt.
# The split is drawn from the global `random` state seeded earlier in the notebook,
# and both split files are overwritten on every run of this cell.
split_dataset()

Total dataset size = 4456
Train size = 3120
Test  size = 1336


In [None]:
# Per-clip preprocessing shared by the train and test datasets.
_FRAME_SIZE = 299                       # output frames are _FRAME_SIZE x _FRAME_SIZE
_CHANNEL_MEAN = [0.485, 0.456, 0.406]   # standard ImageNet per-channel mean (torchvision convention)
_CHANNEL_STD = [0.229, 0.224, 0.225]    # standard ImageNet per-channel std (torchvision convention)

_preprocess_steps = [
    ImglistToTensor(),                  # list of PIL images -> (FRAMES x CHANNELS x HEIGHT x WIDTH) tensor
    transforms.Resize(_FRAME_SIZE),     # resize the smaller edge of every frame to _FRAME_SIZE
    transforms.CenterCrop(_FRAME_SIZE), # square center crop
    transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
]
preprocess = transforms.Compose(_preprocess_steps)

In [None]:
# Both splits live under the same directory and differ only in their annotation file.
DATA_DIR = '/content/drive/MyDrive/Project/MLB/Data/Segment/all_test'
dir = DATA_DIR  # kept in case later cells read `dir`; NOTE: shadows the builtin dir()


def make_video_dataset(annotation_filename, random_shift=True, test_mode=False):
    """Build a VideoFrameDataset rooted at DATA_DIR for one annotation file.

    Each dataset uses NUMBER_OF_SEGMENTS segments with one frame per segment.
    Frame file names follow the template 'frame_{0:012d}.jpg' (a 12-digit
    zero-padded number), and ``preprocess`` is applied as the transform.

    Args:
        annotation_filename: File name, relative to DATA_DIR, of the split's
            annotation list (as written by split_dataset()).
        random_shift: Forwarded unchanged to VideoFrameDataset.
        test_mode: Forwarded unchanged to VideoFrameDataset.

    Returns:
        The constructed VideoFrameDataset.
    """
    return VideoFrameDataset(
        root_path=DATA_DIR,
        annotationfile_path=os.path.join(DATA_DIR, annotation_filename),
        num_segments=NUMBER_OF_SEGMENTS,
        frames_per_segment=1,
        imagefile_template='frame_{0:012d}.jpg',
        transform=preprocess,
        random_shift=random_shift,
        test_mode=test_mode,
    )


train_dataset = make_video_dataset("annotations_train.txt")
# NOTE(review): the test split uses the same sampling flags as training
# (random_shift=True, test_mode=False), which presumably randomizes frame positions,
# so evaluation may differ between runs. Consider random_shift=False / test_mode=True
# after confirming their meaning in VideoFrameDataset.
test_dataset = make_video_dataset("annotations_test.txt")