# Learn To Synchronize Videos

## Model declaration

In [22]:
from importlib import reload
import torch
import torch.nn as nn
import torchvision.models as models
from torch.optim import lr_scheduler
import sync_net
reload(sync_net)
from sync_net import reset_first_and_last_layers, TripletNet, TripletLoss
from data_loader import get_datasets, get_test_set
from torch.utils.data import DataLoader
from trainer import fit
# --- Device selection -------------------------------------------------------
# True when at least one CUDA device is visible. The flag is also forwarded to
# `fit` in the training cell.
cuda = torch.cuda.is_available()
if cuda:
    # Selecting a CUDA device is only valid when one exists; calling this
    # unconditionally raises on CPU-only machines.
    torch.cuda.set_device(0)

# --- Model -------------------------------------------------------------------
# ImageNet-pretrained ResNet-50 used as the embedding network. The project
# helper `reset_first_and_last_layers` modifies it in place (see sync_net).
embedding_net = models.resnet50(pretrained=True)
reset_first_and_last_layers(embedding_net)
model = TripletNet(embedding_net)
if cuda:
    # Move the parameters to GPU 0 before wrapping and before the optimizer is
    # created, so the optimizer references the GPU parameters.
    model.cuda(0)
# Wrap in DataParallel even on CPU: the checkpoints loaded further down are
# loaded into this wrapped model, so the wrapper must be present in both cases.
model = nn.DataParallel(model)

# --- Optimisation ------------------------------------------------------------
lr = 1e-3
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
loss_fn = TripletLoss(margin=0.5)
# Multiply the learning rate by 0.1 every 8 epochs.
scheduler = lr_scheduler.StepLR(optimizer, step_size=8, gamma=0.1, last_epoch=-1)

# --- Training configuration --------------------------------------------------
n_epochs = 20
log_interval = 100
start_epoch = 0  # overwritten by the "Load training state" cell when resuming
# Folder holding the training_state_<n>.pth checkpoints read by later cells.
# NOTE(review): absolute, machine-specific Windows path; adjust per machine.
save_path = r"C:\Users\root\Projects\VideoSynchronizationWithPytorch\trainings\base"

## Load dataset

In [None]:
# Root folder of the angiography sequences used for training (network share).
training_path = r'\\primnis.gi.polymtl.ca\dfs\cheriet\Images\Cardiologie\Angiographie'
# Folder used for validation.
# NOTE(review): this folder (KR-11) lives inside `training_path`; confirm that
# `get_datasets` excludes it from the training set, otherwise the validation
# data overlaps with the training data.
validation_path = r'\\primnis.gi.polymtl.ca\dfs\cheriet\Images\Cardiologie\Angiographie\KR-11'
# Build both datasets with the project helper from `data_loader`.
training_set, validation_set = get_datasets(training_path, validation_path)

## Load training state

In [None]:
# Resume training from a previously saved checkpoint (optional cell).
load_state_path = save_path + r"\training_state_0.pth"
print(load_state_path)
# Without `map_location`, tensors saved from a GPU can only be restored when
# CUDA is available. Remap them to the CPU when it is not; keep the default
# behaviour otherwise.
# NOTE: torch.load unpickles the file - only load checkpoints you trust.
state = torch.load(load_state_path, map_location=None if cuda else "cpu")

# The checkpoint stores the last completed epoch; continue with the next one.
start_epoch = int(state['epoch']) + 1
model.load_state_dict(state['model'])
optimizer.load_state_dict(state['optimizer'])
scheduler.load_state_dict(state['scheduler'])

## Train

In [None]:
# Ask PyTorch to release the GPU memory it is caching but not currently using.
# NOTE(review): memory still referenced by live tensors (e.g. variables kept
# alive in the notebook namespace) is presumably not released by this call,
# which would explain why it does not always free memory - confirm.
torch.cuda.empty_cache()  # Doesn't always work to free the GPU memory

In [None]:
# Batch loaders for the training and validation sets (identical settings).
train_loader = DataLoader(
    training_set,
    batch_size=20,
    shuffle=True,
    num_workers=4,
)
val_loader = DataLoader(
    validation_set,
    batch_size=20,
    shuffle=True,
    num_workers=4,
)

# Run the project training loop, starting at `start_epoch` and writing
# progress to `save_path`.
fit(
    train_loader,
    val_loader,
    model,
    loss_fn,
    optimizer,
    scheduler,
    n_epochs,
    cuda,
    log_interval,
    start_epoch=start_epoch,
    save_progress_path=save_path,
)

## Test trained model

In [2]:
# Restore the weights of a trained checkpoint and switch the model to
# evaluation mode for testing.
load_state_path = save_path + r"\training_state_3.pth"
print(load_state_path)
# Without `map_location`, tensors saved from a GPU can only be restored when
# CUDA is available. Remap them to the CPU when it is not; keep the default
# behaviour otherwise.
# NOTE: torch.load unpickles the file - only load checkpoints you trust.
state = torch.load(load_state_path, map_location=None if cuda else "cpu")
model.load_state_dict(state['model'])
# eval() returns the module itself, so as the last expression of the cell the
# notebook displays the full model repr.
model.eval()

C:\Users\root\Projects\VideoSynchronizationWithPytorch\trainings\base\training_state_3.pth


DataParallel(
  (module): TripletNet(
    (embedding_net): ResNet(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(

In [3]:
# Root folder scanned for test sequences (network share).
# NOTE(review): this is the same folder as the training root used in the
# "Load dataset" cell - confirm that testing on it is intended.
test_path = r'\\primnis.gi.polymtl.ca\dfs\cheriet\Images\Cardiologie\Angiographie'
# Build the test dataset with the project helper from `data_loader`.
test_set = get_test_set(test_path)

105 valid frames in \\primnis.gi.polymtl.ca\dfs\cheriet\Images\Cardiologie\Angiographie\AA-4\export\LCA_30LAO25CAU
110 valid frames in \\primnis.gi.polymtl.ca\dfs\cheriet\Images\Cardiologie\Angiographie\AA-4\export\LCA_30RAO
104 valid frames in \\primnis.gi.polymtl.ca\dfs\cheriet\Images\Cardiologie\Angiographie\AA-4\export\LCA_30RAO25CAU
78 valid frames in \\primnis.gi.polymtl.ca\dfs\cheriet\Images\Cardiologie\Angiographie\AA-4\export\LCA_AP
79 valid frames in \\primnis.gi.polymtl.ca\dfs\cheriet\Images\Cardiologie\Angiographie\AA-4\export\LCA_LAT
121 valid frames in \\primnis.gi.polymtl.ca\dfs\cheriet\Images\Cardiologie\Angiographie\AA-4\export\RCA_AP
113 valid frames in \\primnis.gi.polymtl.ca\dfs\cheriet\Images\Cardiologie\Angiographie\AA-4\export\RCA_LAT
75 valid frames in \\primnis.gi.polymtl.ca\dfs\cheriet\Images\Cardiologie\Angiographie\ABL-5\export\LCA_30LAO25CRA
78 valid frames in \\primnis.gi.polymtl.ca\dfs\cheriet\Images\Cardiologie\Angiographie\ABL-5\export\LCA_30RAO
85 vali

In [23]:
# One whole video sequence per batch, in dataset order.
test_loader = DataLoader(test_set, batch_size=1, shuffle=False, num_workers=1)
# Inference only: disable autograd so that no computation graph is built for
# each forward pass.
with torch.no_grad():
    for batch_index, sequences in enumerate(test_loader):
        # sequences: (batch, video_frame, channel, width, height)
        for i in range(sequences.size(1)):
            # Embed frame `i` of every sequence in the batch.
            embedding = model(sequences[:, i])  # (1, 1000)

torch.Size([1, 103, 3, 224, 224])
torch.Size([103, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 1000])
torch.Size([1, 103, 3, 224, 224])
torch.Size([103, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 1000])
torch.Size([1, 103, 3, 224, 224])
torch.Size([103, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 1000])
torch.Size([1, 103, 3, 224, 224])
torch.Size([103, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 1000])
torch.Size([1, 103, 3, 224, 224])
torch.Size([103, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 1000])
torch.Size([1, 103, 3, 224, 224])
torch.Size([103, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 1000])
torch.Size([1, 103, 3, 224, 224])
torch.Size([103, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 1000])
torch.Size([1, 103, 3, 224, 224])
torch.Size([103, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 1000])
torch.Size([1, 103, 3, 224, 224])
torch.Size([103, 3, 224, 224])
torch.S

torch.Size([1, 3, 224, 224])
torch.Size([1, 1000])
torch.Size([1, 103, 3, 224, 224])
torch.Size([103, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 1000])
torch.Size([1, 103, 3, 224, 224])
torch.Size([103, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 1000])
torch.Size([1, 103, 3, 224, 224])
torch.Size([103, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 1000])
torch.Size([1, 103, 3, 224, 224])
torch.Size([103, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 1000])
torch.Size([1, 103, 3, 224, 224])
torch.Size([103, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 1000])
torch.Size([1, 103, 3, 224, 224])
torch.Size([103, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 1000])
torch.Size([1, 103, 3, 224, 224])
torch.Size([103, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 1000])
torch.Size([1, 103, 3, 224, 224])
torch.Size([103, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 1000])
torch.Size([1, 103, 3

KeyboardInterrupt: 