### 预测代码

In [2]:
import os
import time
import torch
import einops
import random
import numpy as np
import torch.optim as optim
import torch.nn as nn
import torchvision.models as models
from natsort import natsorted
from tensorboardX import SummaryWriter
from datetime import datetime
from argparse import ArgumentParser
from torch.utils.data import DataLoader
from data import *
from utils import warp, update_train_log, write_train_log, update_eval_log, write_eval_log, print_eval_log

def main(args):
  verbose = args.verbose
  device = 'cuda' if torch.cuda.is_available() else 'cpu'

  # DATASET
  images = [x for x in natsorted(os.listdir(args.dir)) if ('.jpg' in x) or ('.png' in x)]

  # MODEL
  model_stn = models.resnet50(pretrained=True)
  model_stn.fc = nn.Linear(2048, 8)
  model = models.resnet50(pretrained=True)
  model.fc = nn.Linear(2048, 720)
  resume_path = '../models/{}.pth'.format(args.verbose)
  stn_resume_path = '../models/{}_st.pth'.format(args.verbose)
  model.load_state_dict(torch.load(resume_path))
  model_stn.load_state_dict(torch.load(stn_resume_path))
  model_stn.to(device)
  model.to(device)

  for img_name in images:
    with torch.no_grad():
      model.eval()
      model_stn.eval()

      #MODEL
      img = cv2.imread(os.path.join(args.dir, img_name))
      img = cv2.resize(img, (224, 224))/255.
      img = einops.rearrange(img, 'h w c -> c h w')
      img = torch.Tensor(img)
      img = img.float().to(device)
      img = torch.unsqueeze(img, 0)

      pred_st = model_stn(img)
      pred_st = torch.cat([pred_st,torch.ones(1,1).to(device)], 1)
      Minv_pred = torch.reshape(pred_st, (-1, 3, 3))
      img_ = warp(img, Minv_pred)
      pred = model(img_)

      #top 3 predictions
      max_pred = torch.argsort(pred, dim=1, descending=True)
      max_pred = max_pred[0,:3]
      max_h = max_pred[0] // 60
      max_m = max_pred[0] % 60

      print(img_name, max_h.cpu().numpy(), max_m.cpu().numpy())
      
      #img = einops.rearrange(img[0], 'c h w -> h w c').cpu().numpy()[:,:,::-1] * 255 
      #img_ = einops.rearrange(img_[0], 'c h w -> h w c').cpu().numpy()[:,:,::-1] * 255

      #uncomment this to save image
      #os.makedirs('../viz/{}/{}'.format(verbose,names[i]), exist_ok=True)
      #if idx < 100:
      #cv2.imwrite('../viz/{}/{}/{}_{}_{}.png'.format(verbose,names[i],idx, int(max_pred[0]), int(hr*60+mn)), img)
      #cv2.imwrite('../viz/{}/{}/{}_w.png'.format(verbose,names[i],idx), img_)


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument('--verbose', type=str, default='full+++')
    parser.add_argument('--dir', type=str, default='../data/demo')

    args = parser.parse_args(args=[])
    main(args)


Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to C:\Users\Administrator/.cache\torch\hub\checkpoints\resnet50-0676ba61.pth
100%|█████████████████████████████████████████████████████████████████████████████| 97.8M/97.8M [00:06<00:00, 15.4MB/s]


Back to the Future 10.03 p.m.png 10 4
Barton Fink 7.59 a.m.png 7 59
Basic 2.55 a.m. 1.png 2 55
Batman 3.12 a.m.png 3 12
Beauty and the Beast 7.00 p.m.png 6 59
Beetlejuice 6.00 p.m.png 6 1
Before the Devil Knows You're Dead 9.12 a.m.png 9 11
Being There 6.49 p.m.png 6 49
Big Daddy 10.41 a.m.png 10 40
Big Fish 12.03 p.m.png 0 4
Black Narcissus 5.58 a.m.png 5 57
Black Narcissus 6.01 p.m.png 6 0
Blackboard Jungle 3.25 p.m.png 3 25
Brewster's Millions 11.57 p.m.png 11 56
Buffalo '66 11.08 p.m.png 11 7
Bullitt 6.51 p.m.png 6 52
Burn After Reading 4.53 p.m.png 4 53


### 评估代码

In [3]:
import os
import time
import torch
import einops
import random
import numpy as np
import torch.optim as optim
import torch.nn as nn
import torchvision.models as models
from tensorboardX import SummaryWriter
from datetime import datetime
from argparse import ArgumentParser
from torch.utils.data import DataLoader
from data import *
from utils import warp, update_train_log, write_train_log, update_eval_log, write_eval_log, print_eval_log

def main(args):
  verbose = args.verbose
  device = 'cuda' if torch.cuda.is_available() else 'cpu'

  # DATASET
  coco_dataset = ClockEval('coco')
  openimg_dataset = ClockEval('openimages')
  movie_dataset = ClockEval('clockmovies')
  coco_loader = DataLoader(coco_dataset, batch_size=1, shuffle=False)
  openimg_loader = DataLoader(openimg_dataset, batch_size=1, shuffle=False)
  movie_loader = DataLoader(movie_dataset, batch_size=1, shuffle=False)

  # MODEL
  model_stn = models.resnet50(pretrained=True)
  model_stn.fc = nn.Linear(2048, 8)
  model = models.resnet50(pretrained=True)
  model.fc = nn.Linear(2048, 720)
  resume_path = '../models/{}.pth'.format(args.verbose)
  stn_resume_path = '../models/{}_st.pth'.format(args.verbose)
  model.load_state_dict(torch.load(resume_path))
  model_stn.load_state_dict(torch.load(stn_resume_path))
  model_stn.to(device)
  model.to(device)

  names = ['COCO', 'OpenImages', 'ClockMovies']
  for i, vloader in enumerate([coco_loader, openimg_loader]):
    eval_log = {'top_1': [],'top_2': [],'top_3': [],'top_1_hr': [], 'top_1_min': [], 'iou50': []}
    for idx, val_sample in enumerate(vloader):
      with torch.no_grad():
        model.eval()
        model_stn.eval()
        
        img, hour, minute, iou50 = val_sample
        img = img.float().to(device)
        hr = hour.type(torch.long).to(device)
        mn = minute.type(torch.long).to(device)

        #MODEL
        pred_st = model_stn(img)
        pred_st = torch.cat([pred_st,torch.ones(1,1).to(device)], 1)
        Minv_pred = torch.reshape(pred_st, (-1, 3, 3))
        img_ = warp(img, Minv_pred)
        pred = model(img_)


        #top 3 predictions
        max_pred = torch.argsort(pred, dim=1, descending=True)
        max_pred = max_pred[0,:3]
        max_h = max_pred[0] // 60
        max_m = max_pred[0] % 60

        minute_err = torch.sum(torch.abs(max_m - mn))
        both_err = torch.abs(max_pred - (hr * 60 + mn))
        top_1 = float(both_err[0] <= 1) + float(both_err[0] == 719)
        top_2 = float(both_err[1] <= 1) + float(both_err[1] == 719)
        top_3 = float(both_err[2] <= 1) + float(both_err[2] == 719)
        top_1_hr = float(torch.sum(max_h == hr))
        top_1_min = float(minute_err <= 1) + float(minute_err == 59)

        update_eval_log(eval_log, top_1, top_2, top_3, top_1_hr, top_1_min, int(iou50))
        img = einops.rearrange(img[0], 'c h w -> h w c').cpu().numpy()[:,:,::-1] * 255 
        img_ = einops.rearrange(img_[0], 'c h w -> h w c').cpu().numpy()[:,:,::-1] * 255

        #uncomment this to save image
        #os.makedirs('../viz/{}/{}'.format(verbose,names[i]), exist_ok=True)
        #if idx < 100:
        #cv2.imwrite('../viz/{}/{}/{}_{}_{}.png'.format(verbose,names[i],idx, int(max_pred[0]), int(hr*60+mn)), img)
        #cv2.imwrite('../viz/{}/{}/{}_w.png'.format(verbose,names[i],idx), img_)

    print_eval_log(eval_log, i)

if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument('--verbose', type=str, default='full+++')

    args = parser.parse_args(args=[])
    main(args)




KeyboardInterrupt: 

### train


In [None]:
import os
import time
import torch
import einops
import random
import numpy as np
import torch.optim as optim
import torch.nn as nn
import torchvision.models as models
from tensorboardX import SummaryWriter
from datetime import datetime
from argparse import ArgumentParser
from torch.utils.data import DataLoader
from torch.cuda.amp import GradScaler, autocast
from data import *
from utils import warp, update_train_log, write_train_log, update_eval_log, write_eval_log

def main(args):
    bsz = 32
    lr = 1e-4
    verbose = args.verbose
    use_stn = not args.no_stn
    augment = not args.no_augment
    homography = not args.no_homography
    artefacts = not args.no_artefacts

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    dt_string = datetime.now().strftime("%m_%d_%H_%M")
    writer = SummaryWriter(logdir='../logs/{}-{}'.format(dt_string, verbose))

    # DATASET
    trn_dataset = ClockSyn(augment=augment, use_homography=homography, use_artefacts=artefacts)
    coco_dataset = ClockEval('coco')
    openimg_dataset = ClockEval('openimages')
    trn_loader = DataLoader(trn_dataset, batch_size=bsz, shuffle=True, num_workers=4)
    coco_loader = DataLoader(coco_dataset, batch_size=1, shuffle=False, num_workers=4)
    openimg_loader = DataLoader(openimg_dataset, batch_size=1, shuffle=False, num_workers=4)

    # MODEL
    model_stn = models.resnet50(pretrained=True)
    model_stn.fc = nn.Linear(2048, 8)
    model = models.resnet50(pretrained=True)
    model.fc = nn.Linear(2048, 720)
    if args.resume_path:
        model.load_state_dict(torch.load(args.resume_path))
        model_stn.load_state_dict(torch.load(args.stn_resume_path))
    model_stn.to(device)
    model.to(device)

    # OPTIMIZER
    optimizer = optim.Adam(list(model.parameters()) + list(model_stn.parameters()), lr=lr)
    cross_entropy = torch.nn.CrossEntropyLoss()
    scaler = GradScaler()

    for ep in range(40):
        print('Epoch {}'.format(ep))
        train_log = {'loss_cls': [], 'loss_reg': [], 'hour_acc': [], 'minute_acc': []}

        for i, trn_sample in enumerate(trn_loader):
            model.train()
            model_stn.train()
            optimizer.zero_grad()

            img, hour, minute, Minv = trn_sample
            img = img.float().to(device)
            Minv = Minv.to(device)
            hour = hour.type(torch.long).to(device)
            minute = minute.type(torch.long).to(device)

            # PREDICT
            with autocast():
                if use_stn:
                    pred_st = model_stn(img)
                    pred_st = torch.cat([pred_st, torch.ones(bsz, 1).to(device)], 1)
                    Minv_pred = torch.reshape(pred_st, (-1, 3, 3))
                    img_ = warp(img, Minv_pred)
                    if random.random() < 0.5:
                        pred = model(img_)
                    else:
                        pred = model(img)
                    loss_reg = torch.mean(torch.abs((Minv.reshape(bsz, 9) - pred_st)))
                    loss_cls = cross_entropy(pred, hour * 60 + minute)
                else:
                    pred = model(img)
                    loss_cls = cross_entropy(pred, hour * 60 + minute)
                    loss_reg = 0.

                # LOSS
                loss = 100 * loss_reg + loss_cls

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            # METRIC
            max_pred = torch.argsort(pred, dim=1, descending=True)
            max_pred = max_pred[:, 0]
            max_h = max_pred // 60
            max_m = max_pred % 60
            hour_acc = float(torch.sum(max_h == hour)) / bsz
            minute_acc = float(torch.sum(torch.abs(max_m - minute) <= 1)) / bsz

            update_train_log(train_log, loss_cls, loss_reg, hour_acc, minute_acc)

            if i % 10 == 0:  # Print every 10 batches
                print(f'Batch {i}/{len(trn_loader)} - Loss: {loss.item():.4f}, Hour Accuracy: {hour_acc:.4f}, Minute Accuracy: {minute_acc:.4f}')

            if i == 0:
                writer.add_images('train', img, ep)
                if use_stn:
                    writer.add_images('train_warped', img_, ep)
        write_train_log(writer, train_log, use_stn, ep)

        names = ['COCO', 'OpenImages'，'ClockMovies']
        for i, vloader in enumerate([coco_loader, openimg_loader]):
            eval_log = {'top_1': [], 'top_2': [], 'top_3': [], 'top_1_hr': [], 'top_1_min': [], 'iou50': []}
            imgs = []
            imgs_warped = []
            for idx, val_sample in enumerate(vloader):
                with torch.no_grad():
                    model.eval()
                    model_stn.eval()

                    img, hour, minute, iou50 = val_sample
                    img = img.float().to(device)
                    hr = hour.type(torch.long).to(device)
                    mn = minute.type(torch.long).to(device)

                    # MODEL
                    if use_stn:
                        pred_st = model_stn(img)
                        pred_st = torch.cat([pred_st, torch.ones(1, 1).to(device)], 1)
                        Minv_pred = torch.reshape(pred_st, (-1, 3, 3))
                        img_ = warp(img, Minv_pred)
                        pred = model(img_)
                    else:
                        pred = model(img)

                    # Top 3 predictions
                    max_pred = torch.argsort(pred, dim=1, descending=True)
                    max_pred = max_pred[0, :3]
                    max_h = max_pred[0] // 60
                    max_m = max_pred[0] % 60

                    minute_err = torch.sum(torch.abs(max_m - mn))
                    both_err = torch.abs(max_pred - (hr * 60 + mn))
                    top_1 = float(both_err[0] <= 1) + float(both_err[0] == 719)
                    top_2 = float(both_err[1] <= 1) + float(both_err[1] == 719)
                    top_3 = float(both_err[2] <= 1) + float(both_err[2] == 719)
                    top_1_hr = float(torch.sum(max_h == hr))
                    top_1_min = float(minute_err <= 1) + float(minute_err == 59)

                    update_eval_log(eval_log, top_1, top_2, top_3, top_1_hr, top_1_min, int(iou50))

                    if idx < 64:
                        imgs.append(img[0])
                        if use_stn:
                            imgs_warped.append(img_[0])

            print(f'Evaluation on {names[i]} - Top 1 Accuracy: {sum(eval_log["top_1"]) / len(eval_log["top_1"]):.4f}')
            writer.add_images(names[i], torch.stack(imgs, 0), ep)
            if use_stn:
                writer.add_images(names[i] + '_warped', torch.stack(imgs_warped, 0), ep)
            write_eval_log(writer, eval_log, i, ep)

        torch.save(model.state_dict(), '../models/{}.pth'.format(verbose))
        if use_stn:
            torch.save(model_stn.state_dict(), '../models/{}_st.pth'.format(verbose))

if __name__ == "__main__":
    parser = ArgumentParser()
    # Ablations
    parser.add_argument('--no_augment', action='store_true')
    parser.add_argument('--no_homography', action='store_true')
    parser.add_argument('--no_artefacts', action='store_true')
    parser.add_argument('--no_stn', action='store_true')

    parser.add_argument('--verbose', type=str, default='base')
    parser.add_argument('--resume_path', type=str, default=None)
    parser.add_argument('--stn_resume_path', type=str, default=None)

    args = parser.parse_args(args=[])
    main(args)




Epoch 0
Batch 0/2500 - Loss: 53.9032, Hour Accuracy: 0.0938, Minute Accuracy: 0.0938
Batch 10/2500 - Loss: 39.9347, Hour Accuracy: 0.1562, Minute Accuracy: 0.0312
Batch 20/2500 - Loss: 26.2138, Hour Accuracy: 0.0312, Minute Accuracy: 0.0000
Batch 30/2500 - Loss: 22.4486, Hour Accuracy: 0.2188, Minute Accuracy: 0.0625
Batch 40/2500 - Loss: 21.5645, Hour Accuracy: 0.0625, Minute Accuracy: 0.0312
Batch 50/2500 - Loss: 19.2791, Hour Accuracy: 0.1562, Minute Accuracy: 0.0938
Batch 60/2500 - Loss: 16.4688, Hour Accuracy: 0.1875, Minute Accuracy: 0.0312
Batch 70/2500 - Loss: 18.2203, Hour Accuracy: 0.1562, Minute Accuracy: 0.0938
Batch 80/2500 - Loss: 16.9972, Hour Accuracy: 0.1875, Minute Accuracy: 0.0000
Batch 90/2500 - Loss: 16.9511, Hour Accuracy: 0.1250, Minute Accuracy: 0.0000
Batch 100/2500 - Loss: 14.4926, Hour Accuracy: 0.1562, Minute Accuracy: 0.0312
Batch 110/2500 - Loss: 14.3264, Hour Accuracy: 0.1875, Minute Accuracy: 0.0938
Batch 120/2500 - Loss: 15.4108, Hour Accuracy: 0.1875, 

Batch 1040/2500 - Loss: 8.4865, Hour Accuracy: 0.3438, Minute Accuracy: 0.1875
Batch 1050/2500 - Loss: 9.8708, Hour Accuracy: 0.4375, Minute Accuracy: 0.1562
Batch 1060/2500 - Loss: 8.7159, Hour Accuracy: 0.4062, Minute Accuracy: 0.2812
Batch 1070/2500 - Loss: 8.8095, Hour Accuracy: 0.5312, Minute Accuracy: 0.1875
Batch 1080/2500 - Loss: 10.2803, Hour Accuracy: 0.3750, Minute Accuracy: 0.2812
Batch 1090/2500 - Loss: 9.4814, Hour Accuracy: 0.4062, Minute Accuracy: 0.1875
Batch 1100/2500 - Loss: 8.3335, Hour Accuracy: 0.5312, Minute Accuracy: 0.2500
Batch 1110/2500 - Loss: 9.8326, Hour Accuracy: 0.5938, Minute Accuracy: 0.2188
Batch 1120/2500 - Loss: 8.8084, Hour Accuracy: 0.5312, Minute Accuracy: 0.1875
Batch 1130/2500 - Loss: 8.2571, Hour Accuracy: 0.4062, Minute Accuracy: 0.3438
Batch 1140/2500 - Loss: 8.7180, Hour Accuracy: 0.4375, Minute Accuracy: 0.2812
Batch 1150/2500 - Loss: 8.9013, Hour Accuracy: 0.3750, Minute Accuracy: 0.3125
Batch 1160/2500 - Loss: 8.5571, Hour Accuracy: 0.37

Batch 2080/2500 - Loss: 5.9176, Hour Accuracy: 0.6562, Minute Accuracy: 0.5000
Batch 2090/2500 - Loss: 7.6817, Hour Accuracy: 0.4688, Minute Accuracy: 0.4375
Batch 2100/2500 - Loss: 6.7606, Hour Accuracy: 0.6562, Minute Accuracy: 0.5000
Batch 2110/2500 - Loss: 7.3670, Hour Accuracy: 0.5938, Minute Accuracy: 0.3750
Batch 2120/2500 - Loss: 5.8851, Hour Accuracy: 0.7188, Minute Accuracy: 0.5938
Batch 2130/2500 - Loss: 6.7397, Hour Accuracy: 0.5312, Minute Accuracy: 0.4062
Batch 2140/2500 - Loss: 6.9214, Hour Accuracy: 0.7812, Minute Accuracy: 0.5625
Batch 2150/2500 - Loss: 7.1479, Hour Accuracy: 0.7188, Minute Accuracy: 0.5625
Batch 2160/2500 - Loss: 7.2542, Hour Accuracy: 0.5000, Minute Accuracy: 0.2812
Batch 2170/2500 - Loss: 7.9994, Hour Accuracy: 0.4688, Minute Accuracy: 0.3438
Batch 2180/2500 - Loss: 7.5007, Hour Accuracy: 0.4688, Minute Accuracy: 0.4375
Batch 2190/2500 - Loss: 7.0131, Hour Accuracy: 0.7188, Minute Accuracy: 0.5625
Batch 2200/2500 - Loss: 5.9045, Hour Accuracy: 0.593

Batch 620/2500 - Loss: 5.5818, Hour Accuracy: 0.8438, Minute Accuracy: 0.5312
Batch 630/2500 - Loss: 6.2032, Hour Accuracy: 0.7812, Minute Accuracy: 0.5312
Batch 640/2500 - Loss: 6.4402, Hour Accuracy: 0.7500, Minute Accuracy: 0.6562
Batch 650/2500 - Loss: 6.7349, Hour Accuracy: 0.6875, Minute Accuracy: 0.4375
Batch 660/2500 - Loss: 5.6140, Hour Accuracy: 0.8438, Minute Accuracy: 0.6250
Batch 670/2500 - Loss: 5.5301, Hour Accuracy: 0.7812, Minute Accuracy: 0.5312
Batch 680/2500 - Loss: 6.2650, Hour Accuracy: 0.7188, Minute Accuracy: 0.5625
Batch 690/2500 - Loss: 5.4765, Hour Accuracy: 0.7188, Minute Accuracy: 0.6250
Batch 700/2500 - Loss: 5.9461, Hour Accuracy: 0.7500, Minute Accuracy: 0.6875
Batch 710/2500 - Loss: 7.1229, Hour Accuracy: 0.6562, Minute Accuracy: 0.5938
Batch 720/2500 - Loss: 5.4681, Hour Accuracy: 0.9062, Minute Accuracy: 0.7500
Batch 730/2500 - Loss: 6.5262, Hour Accuracy: 0.9062, Minute Accuracy: 0.7188
Batch 740/2500 - Loss: 6.2379, Hour Accuracy: 0.7500, Minute Acc

Batch 1670/2500 - Loss: 5.1963, Hour Accuracy: 0.7500, Minute Accuracy: 0.8125
Batch 1680/2500 - Loss: 5.7830, Hour Accuracy: 0.7500, Minute Accuracy: 0.4062
Batch 1690/2500 - Loss: 5.5255, Hour Accuracy: 1.0000, Minute Accuracy: 0.8125
Batch 1700/2500 - Loss: 5.9635, Hour Accuracy: 0.8438, Minute Accuracy: 0.5625
Batch 1710/2500 - Loss: 5.0869, Hour Accuracy: 0.9062, Minute Accuracy: 0.5938
Batch 1720/2500 - Loss: 5.8914, Hour Accuracy: 0.7812, Minute Accuracy: 0.5000
Batch 1730/2500 - Loss: 5.6353, Hour Accuracy: 0.8125, Minute Accuracy: 0.5312
Batch 1740/2500 - Loss: 6.2918, Hour Accuracy: 0.7500, Minute Accuracy: 0.5000
Batch 1750/2500 - Loss: 6.1363, Hour Accuracy: 0.7812, Minute Accuracy: 0.5312
Batch 1760/2500 - Loss: 4.9717, Hour Accuracy: 0.9688, Minute Accuracy: 0.8750
Batch 1770/2500 - Loss: 5.4347, Hour Accuracy: 0.8438, Minute Accuracy: 0.6875
Batch 1780/2500 - Loss: 5.1762, Hour Accuracy: 0.7812, Minute Accuracy: 0.5938
Batch 1790/2500 - Loss: 6.2866, Hour Accuracy: 0.781

Batch 200/2500 - Loss: 4.9379, Hour Accuracy: 0.9062, Minute Accuracy: 0.7188
Batch 210/2500 - Loss: 4.2310, Hour Accuracy: 0.8438, Minute Accuracy: 0.7812
Batch 220/2500 - Loss: 4.5686, Hour Accuracy: 0.9062, Minute Accuracy: 0.6875
Batch 230/2500 - Loss: 4.9128, Hour Accuracy: 0.8125, Minute Accuracy: 0.6562
Batch 240/2500 - Loss: 4.6539, Hour Accuracy: 0.7500, Minute Accuracy: 0.5938
Batch 250/2500 - Loss: 4.1944, Hour Accuracy: 0.9688, Minute Accuracy: 0.7812
Batch 260/2500 - Loss: 5.3547, Hour Accuracy: 0.7812, Minute Accuracy: 0.6250
Batch 270/2500 - Loss: 4.8076, Hour Accuracy: 0.9062, Minute Accuracy: 0.6562
Batch 280/2500 - Loss: 4.5734, Hour Accuracy: 0.8438, Minute Accuracy: 0.8750
Batch 290/2500 - Loss: 4.2426, Hour Accuracy: 0.9375, Minute Accuracy: 0.8750
Batch 300/2500 - Loss: 5.7592, Hour Accuracy: 0.8750, Minute Accuracy: 0.6562
Batch 310/2500 - Loss: 4.5656, Hour Accuracy: 0.9062, Minute Accuracy: 0.9062


### label


In [38]:
import os
import cv2
from PIL import Image
import logging

def create_folder_from_video(video_path):
    # 获取视频文件名（不带扩展名）
    video_name = os.path.splitext(os.path.basename(video_path))[0]
    
    # 创建同名文件夹
    folder_path = os.path.join(os.path.dirname(video_path), video_name)
    if not os.path.exists(folder_path):
        os.makedirs(folder_path)
    
    return folder_path

def extract_keyframes(video_path, num_frames=10):
    # 打开视频文件
    cap = cv2.VideoCapture(video_path)
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    frame_interval = max(frame_count // num_frames, 1)
    
    frames = []
    for i in range(0, frame_count, frame_interval):
        cap.set(cv2.CAP_PROP_POS_FRAMES, i)
        ret, frame = cap.read()
        if ret:
            frames.append(frame)
        if len(frames) >= num_frames:
            break
    
    cap.release()
    return frames

def save_frames_as_images(frames, folder_path):
    for idx, frame in enumerate(frames):
        # 将帧转换为PIL图像
        img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        img_path = os.path.join(folder_path, f"{idx + 1}.png")
        img.save(img_path)

def main(video_path):
    # 创建同名文件夹
    folder_path = create_folder_from_video(video_path)
    
    # 从视频中抽取关键帧
    frames = extract_keyframes(video_path)
    
    # 保存关键帧为图片
    save_frames_as_images(frames, folder_path)
    
    print(f"关键帧已保存到文件夹: {folder_path}")


base_dir = 'C:/Users/Administrator/Desktop/itsabouttime-main/itsabouttime-main/data/filtered_data/1/'
vids = natsorted([x for x in os.listdir(base_dir) if '-' in x])
#     print([x for x in os.listdir(base_dir) if '-' in x])
    
for vid in vids:
    extractKeyFrames(base_dir + vid)
    


关键帧已保存到文件夹: C:/Users/Administrator/Desktop/itsabouttime-main/itsabouttime-main/data/filtered_data/1\stock-footage--k-animation-of-clock-arrows-walking-fast-clockwise-over-white-backgrount-time-passing-concept
关键帧已保存到文件夹: C:/Users/Administrator/Desktop/itsabouttime-main/itsabouttime-main/data/filtered_data/1\stock-footage--k-animation-of-clock-arrows-walking-fast-clockwise-over-white-backgrount-time-passing-concept (1)
关键帧已保存到文件夹: C:/Users/Administrator/Desktop/itsabouttime-main/itsabouttime-main/data/filtered_data/1\stock-footage--k-animation-of-clock-arrows-walking-fast-clockwise-over-white-backgrount-time-passing-concept (1)
关键帧已保存到文件夹: C:/Users/Administrator/Desktop/itsabouttime-main/itsabouttime-main/data/filtered_data/1\stock-footage--k-animation-of-clock-arrows-walking-fast-clockwise-over-white-backgrount-time-passing-concept
关键帧已保存到文件夹: C:/Users/Administrator/Desktop/itsabouttime-main/itsabouttime-main/data/filtered_data/1\stock-footage--k-day-to-night-time-lapse-at-the-fish-mar

关键帧已保存到文件夹: C:/Users/Administrator/Desktop/itsabouttime-main/itsabouttime-main/data/filtered_data/1\stock-footage-timelapse-of-a-blue-clock-on-a-white-wall-the-clock-starts-ticking-at-and-ends-at
关键帧已保存到文件夹: C:/Users/Administrator/Desktop/itsabouttime-main/itsabouttime-main/data/filtered_data/1\stock-footage-vintage-aged-pocket-watch-clock-timelapse-time-flowing-speed-analog-clock-motion
关键帧已保存到文件夹: C:/Users/Administrator/Desktop/itsabouttime-main/itsabouttime-main/data/filtered_data/1\stock-footage-wall-clock-show-the-running-time-time-lapse-on-a-modern-wall-clock-at-noon-close-up-to-a-wall
关键帧已保存到文件夹: C:/Users/Administrator/Desktop/itsabouttime-main/itsabouttime-main/data/filtered_data/1\stock-footage-wall-clock-show-the-running-time-time-lapse-on-a-modern-wall-clock-close-up-to-a-wall-clock
关键帧已保存到文件夹: C:/Users/Administrator/Desktop/itsabouttime-main/itsabouttime-main/data/filtered_data/1\stock-footage-wall-clock-show-the-running-time-time-lapse-on-a-modern-white-wall-clock-close-up

In [2]:
import os
import time
import torch
import einops
import random
import cv2
import numpy as np
import torch.nn as nn
import torchvision.models as models
import matplotlib.pyplot as plt
from argparse import ArgumentParser
from utils import warp
from natsort import natsorted
from cyclic_ransac import RANSACRegressor
import logging

def main(args):
    logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
    logger = logging.getLogger()

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    logger.info(f'Device set to {device}')

    # MODEL
    model_stn = models.resnet50(pretrained=True)
    model_stn.fc = nn.Linear(2048, 8)
    model = models.resnet50(pretrained=True)
    model.fc = nn.Linear(2048, 720)
    
    base_dir = 'C:/Users/Administrator/Desktop/itsabouttime-main/itsabouttime-main/data/filtered_data/1/'
    resume_path = f'../models/{args.verbose}.pth'
    stn_resume_path = f'../models/{args.verbose}_st.pth'
    
    logger.info(f'Loading model state from {resume_path}')
    model.load_state_dict(torch.load(resume_path))
    logger.info(f'Loading STN model state from {stn_resume_path}')
    model_stn.load_state_dict(torch.load(stn_resume_path))
    
    model.eval()
    model_stn.eval()
    
    model.to(device)
    model_stn.to(device)
    
    min_vid_length = 5
    min_range = 20
    score_threshold = 0.7
    ransac_threshold = 3

    os.makedirs('../data/labels', exist_ok=True)
    os.makedirs('../plots/pos', exist_ok=True)
    os.makedirs('../plots/neg', exist_ok=True)
    
#     if not args.no_save:
#         label_path = f'../data/labels/{args.verbose}.txt'
#         if os.path.isfile(label_path):
#             logger.info(f'Removing existing label file {label_path}')
#             os.remove(label_path)
    
    label_path = f'../data/labels/{args.verbose}.txt'
#     if os.path.isfile(label_path):
#         logger.info(f'Removing existing label file {label_path}')
#         os.remove(label_path)
    
    vids = natsorted([x for x in os.listdir(base_dir) if '-' in x])
#     print([x for x in os.listdir(base_dir) if '-' in x])
    logger.info(f'Found {len(vids)} videos for processing')
    
    
#     for vid in vids:
#         extractKeyFrames(base_dir + vid)
    
    
    for vid in vids:
#         print(vid)
        logger.info(f'Processing video: {vid}')
        base_name = os.path.splitext(vid)[0]
        imgs = natsorted([x for x in os.listdir(base_dir + base_name) if '.png' in x])
        ids = [int(x.strip('.png')) for x in imgs]
        
        print(len(imgs))
        print(min_vid_length)
        print(len(imgs) > min_vid_length)
        
        
        if len(imgs) > min_vid_length:
            length = len(imgs)
            frame_gap = len(imgs) // 100 + 1
            imgs = imgs[::frame_gap]
            logger.debug(f'Selected {len(imgs)} frames from video')
            
            data = []
            for i in imgs:
                img = cv2.imread(base_dir + vid + '/' + i)
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                img = cv2.resize(img, (224, 224))
                img = einops.rearrange(img, 'h w c -> c h w') / 255.
                data.append(img)
            data = np.stack(data, 0)
            
            data = torch.Tensor(data).float().to(device)
            
            pred_st = model_stn(data)
            pred_st = torch.cat([pred_st, torch.ones(len(imgs), 1).to(device)], 1)
            Minv_pred = torch.reshape(pred_st, (-1, 3, 3))
            data_ = warp(data, Minv_pred, sz=224)
            
            pred = model(data_)
            max_pred = torch.argsort(pred, dim=1, descending=True)
            max_pred = max_pred[:, 0].detach().cpu().numpy()
            
            X = np.array([int(x.strip('.png')) for x in imgs]).reshape(-1, 1)
            y = max_pred
            ransac = RANSACRegressor(residual_threshold=ransac_threshold, stop_probability=0.999)
            ransac.fit(X, y)
            inlier_mask = ransac.inlier_mask_
            outlier_mask = np.logical_not(inlier_mask)
            line_X = np.array(ids).reshape(-1, 1)
            line_y_ransac = ransac.predict(line_X)
            line_y_plot = ransac.predict(X)
            score = np.sum(inlier_mask) / len(imgs)
            logger.debug(f'RANSAC score: {score}')
            
            if (score > score_threshold) and (np.max(line_y_ransac) - np.min(line_y_ransac) > min_range):
                valid = True
                logger.info(f'Video {vid} is valid')
            else:
                valid = False
                logger.info(f'Video {vid} is not valid')
            
            if args.plot:
                plt.plot()
                plt.scatter(X[inlier_mask], y[inlier_mask], color="yellowgreen", marker=".", label="Inliers")
                plt.scatter(X[outlier_mask], y[outlier_mask], color="gold", marker=".", label="Outliers")
                plt.plot(line_X, line_y_ransac, color="cornflowerblue", linewidth=2, label="RANSAC regressor")
                folder = 'pos' if valid else 'neg'
                plt.savefig(f'../plots/{folder}/{vid}.png')
                plt.close()
                img_ = cv2.imread(f'../plots/{folder}/{vid}.png')
                H, W, _ = np.shape(img_)
                if valid:
                    cv2.rectangle(img_, (0, 0), (W, H), (0, 255, 0), 20)
                else:
                    cv2.rectangle(img_, (0, 0), (W, H), (0, 0, 255), 20)
                x = int(len(imgs) // 7)
                for i in [0, x, 2*x, 3*x, 4*x, 5*x, 6*x, len(imgs)-1]:
                    img_i = cv2.imread(base_dir + vid + '/' + imgs[i])
                    img_i = cv2.resize(img_i, (W, H))
                    if abs(line_y_plot[i] - y[i]) % 720 <= 3:
                        cv2.rectangle(img_i, (0, 0), (W, H), (0, 255, 0), 20)
                    else:
                        cv2.rectangle(img_i, (0, 0), (W, H), (0, 0, 255), 20)
                    img_ = np.concatenate([img_, np.ones([H, 40, 3])*255, img_i], 1)
                cv2.imwrite(f'../plots/{folder}/{vid}.png', img_)
            
            if valid and not args.no_save:
                print('hello')
                with open(label_path, 'a') as f:
                    f.write(vid + ',' + str(list(np.rint(line_y_ransac).astype(int))))
                    f.write('\n')

if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument('--verbose', type=str, default='base')
    parser.add_argument('--plot', action='store_true')
    parser.add_argument('--no_save', action='store_true')
    args = parser.parse_args(args=[])
    main(args)


2024-05-16 03:43:18,640 - INFO - Device set to cuda
2024-05-16 03:43:19,541 - INFO - Loading model state from ../models/base.pth
2024-05-16 03:43:19,727 - INFO - Loading STN model state from ../models/base_st.pth
2024-05-16 03:43:19,973 - INFO - Found 78 videos for processing
2024-05-16 03:43:19,973 - INFO - Processing video: stock-footage--k-animation-of-clock-arrows-walking-fast-clockwise-over-white-backgrount-time-passing-concept
2024-05-16 03:43:19,974 - DEBUG - Selected 10 frames from video


10
5
True


2024-05-16 03:43:20,602 - DEBUG - RANSAC score: 0.6
2024-05-16 03:43:20,602 - INFO - Video stock-footage--k-animation-of-clock-arrows-walking-fast-clockwise-over-white-backgrount-time-passing-concept is not valid
2024-05-16 03:43:20,603 - INFO - Processing video: stock-footage--k-animation-of-clock-arrows-walking-fast-clockwise-over-white-backgrount-time-passing-concept.mp4
2024-05-16 03:43:20,604 - DEBUG - Selected 10 frames from video


10
5
True


error: OpenCV(4.9.0) D:\a\opencv-python\opencv-python\opencv\modules\imgproc\src\color.cpp:196: error: (-215:Assertion failed) !_src.empty() in function 'cv::cvtColor'


### refine

In [26]:
import os
import time
import torch
import einops
import random
import numpy as np
import torch.optim as optim
import torch.nn as nn
import torchvision.models as models
from tensorboardX import SummaryWriter
from datetime import datetime
from argparse import ArgumentParser
from torch.utils.data import DataLoader
from data import *
from utils import warp, update_train_log, write_train_log, update_eval_log, write_eval_log

def main(args):
  bsz = 32
  lr = 1e-4
  verbose = args.verbose
  use_stn = not args.no_stn

  device = 'cuda' if torch.cuda.is_available() else 'cpu'
  dt_string = datetime.now().strftime("%m_%d_%H_%M")
  writer = SummaryWriter(logdir='../logs/{}-{}'.format(dt_string, verbose))

  # DATASET
  trn_dataset = ClockSyn(augment=True, use_homography=True, use_artefacts=True)
  timelapse_dataset = ClockTimelapse('../data/labels/{}.txt'.format(args.verbose), augment=True)
  coco_dataset = ClockEval('coco')
  openimg_dataset = ClockEval('openimages')
  movie_dataset = ClockEval('clockmovies')
  trn_loader = DataLoader(trn_dataset, batch_size=bsz, shuffle=True)
  timelapse_loader = DataLoader(timelapse_dataset, batch_size=bsz, shuffle=True, drop_last=True)
  coco_loader = DataLoader(coco_dataset, batch_size=1, shuffle=False)
  openimg_loader = DataLoader(openimg_dataset, batch_size=1, shuffle=False)
  movie_loader = DataLoader(movie_dataset, batch_size=1, shuffle=False)

  # MODEL
  model_stn = models.resnet50(pretrained=True)
  model_stn.fc = nn.Linear(2048, 8)
  model = models.resnet50(pretrained=True)
  model.fc = nn.Linear(2048, 720)
  resume_path = '../models/{}.pth'.format(args.verbose)
  stn_resume_path = '../models/{}_st.pth'.format(args.verbose)
  model.load_state_dict(torch.load(resume_path))
  model_stn.load_state_dict(torch.load(stn_resume_path))
  model_stn.to(device)
  model.to(device)

  #OPTIM
  optimizer = optim.Adam(list(model.parameters()) + list(model_stn.parameters()), lr=lr)
  cross_entropy = torch.nn.CrossEntropyLoss()

  for ep in range(0):
    print('Epoch {}'.format(ep))
    train_log = {'loss_cls': [], 'loss_reg': [], 'hour_acc': [], 'minute_acc': []}

    for i, trn_sample in enumerate(zip(trn_loader, timelapse_loader)):
      model.train()
      model_stn.train()
      optimizer.zero_grad()

      img, hour, minute, Minv = trn_sample[0]
      img2, hour2, minute2 = trn_sample[1]

      img = torch.cat([img, img2], 0)
      hour = torch.cat([hour, hour2], 0)
      minute = torch.cat([minute, minute2], 0)

      img = img.float().to(device)
      Minv = Minv.to(device)
      hour = hour.type(torch.long).to(device)
      minute = minute.type(torch.long).to(device)

      # PREDICT
      if use_stn:
        pred_st = model_stn(img)
        pred_st = torch.cat([pred_st,torch.ones(bsz*2,1).to(device)], 1)
        Minv_pred = torch.reshape(pred_st, (-1, 3, 3))
        img_ = warp(img, Minv_pred)
        if random.random() < 0.5:
          pred = model(img_)
        else:
          pred = model(img)
        loss_reg = torch.mean(torch.abs((Minv.reshape(bsz,9) - pred_st[:bsz])))
        loss_cls = cross_entropy(pred, hour * 60 + minute)
      else:
        pred = model(img)
        loss_cls = cross_entropy(pred, hour * 60 + minute)
        loss_reg = 0.

      # LOSS
      loss = 100 * loss_reg + loss_cls
      loss.backward()
      optimizer.step()

      # METRIC
      max_pred = torch.argsort(pred, dim=1, descending=True)  
      max_pred = max_pred[:,0]
      max_h = max_pred // 60
      max_m = max_pred % 60
      hour_acc = float(torch.sum(max_h == hour)) / (2*bsz)
      minute_acc = float(torch.sum(torch.abs(max_m - minute) <= 1)) / (2*bsz)

      update_train_log(train_log, loss_cls, loss_reg, hour_acc, minute_acc)
      if i == 0:
        writer.add_images('train', img, ep)
        if use_stn: writer.add_images('train_warped', img_, ep)
    write_train_log(writer, train_log, use_stn, ep)

    names = ['COCO', 'OpenImages','ClockMovies']
    for i, vloader in enumerate([coco_loader, openimg_loader, movie_loader]):
      eval_log = {'top_1': [],'top_2': [],'top_3': [],'top_1_hr': [], 'top_1_min': [], 'iou50': []}
      imgs = []
      imgs_warped = []
      for idx, val_sample in enumerate(vloader):
        with torch.no_grad():
          model.eval()
          model_stn.eval()
          
          img, hour, minute, iou50 = val_sample
          img = img.float().to(device)
          hr = hour.type(torch.long).to(device)
          mn = minute.type(torch.long).to(device)

          #MODEL
          if use_stn:
            pred_st = model_stn(img)
            pred_st = torch.cat([pred_st,torch.ones(1,1).to(device)], 1)
            Minv_pred = torch.reshape(pred_st, (-1, 3, 3))
            img_ = warp(img, Minv_pred)
            pred = model(img_)
          else:
            pred = model(img)  

          #top 3 predictions
          max_pred = torch.argsort(pred, dim=1, descending=True)
          max_pred = max_pred[0,:3]
          max_h = max_pred[0] // 60
          max_m = max_pred[0] % 60

          minute_err = torch.sum(torch.abs(max_m - mn))
          both_err = torch.abs(max_pred - (hr * 60 + mn))
          top_1 = float(both_err[0] <= 1) + float(both_err[0] == 719)
          top_2 = float(both_err[1] <= 1) + float(both_err[1] == 719)
          top_3 = float(both_err[2] <= 1) + float(both_err[2] == 719)
          top_1_hr = float(torch.sum(max_h == hr))
          top_1_min = float(minute_err <= 1) + float(minute_err == 59)

          update_eval_log(eval_log, top_1, top_2, top_3, top_1_hr, top_1_min, int(iou50))

          if idx < 64:
            imgs.append(img[0])
            if use_stn: imgs_warped.append(img_[0])        
      writer.add_images(names[i], torch.stack(imgs,0), ep)
      if use_stn: writer.add_images(names[i]+'_warped', torch.stack(imgs_warped,0), ep)
      write_eval_log(writer, eval_log, i, ep)

    torch.save(model.state_dict(), '../models/{}+.pth'.format(verbose))
    if use_stn: torch.save(model_stn.state_dict(), '../models/{}+_st.pth'.format(verbose))

if __name__ == "__main__":
    parser = ArgumentParser()

    parser.add_argument('--no_stn', action='store_true')
    parser.add_argument('--verbose', type=str, default='base')
    
    args = parser.parse_args(args=[])
    main(args)


FileNotFoundError: [Errno 2] No such file or directory: '../data/labels/base.txt'