In [1]:
import torch
import torchvision

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def sequence_loss(flow_preds, flow_gt, valid_flow_mask, gamma=0.8, max_flow=400):
    """Exponentially weighted L1 loss over a sequence of flow predictions.

    Args:
        flow_preds: sequence of flow tensors. They are stacked along a new
            leading dimension and averaged over four trailing axes, so each
            entry must be 4-D and all entries must share one shape.
        flow_gt: ground-truth flow; its magnitude is taken over ``dim=1``.
        valid_flow_mask: mask and-ed with the magnitude test below. It is
            indexed as ``[:, None, :, :]``, so it must be 3-D.
        gamma: base of the weights. The last prediction is weighted by
            ``gamma ** 0``, the one before by ``gamma ** 1``, and so on.
        max_flow: pixels whose ground-truth magnitude is not strictly below
            this value are masked out.

    Returns:
        A scalar tensor: the weighted sum of the per-prediction mean errors.

    Raises:
        ValueError: if ``gamma`` is greater than 1.
    """
    # NOTE(review): gamma == 1 passes this check although the message says "< 1".
    if gamma > 1:
        raise ValueError(f"Gamma should be < 1, got {gamma}.")

    # Exclude invalid pixels and pixels with an extremely large displacement.
    gt_magnitude = torch.sum(flow_gt ** 2, dim=1).sqrt()
    keep = (valid_flow_mask & (gt_magnitude < max_flow))[:, None, :, :]

    stacked = torch.stack(flow_preds)  # (num_flow_updates, *flow_preds[0].shape)
    n_updates = stacked.shape[0]

    # Masked pixels are zeroed but still counted in the mean's denominator.
    l1_per_update = ((stacked - flow_gt).abs() * keep).mean(axis=(1, 2, 3, 4))

    # Exponents run n_updates - 1, ..., 1, 0 so later predictions weigh more
    # whenever gamma < 1.
    exponents = torch.arange(n_updates - 1, -1, -1).to(flow_gt.device)
    decay = gamma ** exponents

    return (l1_per_update * decay).sum()

In [3]:
def compute_metrics(flow_pred, flow_gt, valid_flow_mask=None):
    """Compute end-point-error statistics for a single flow prediction.

    Args:
        flow_pred: predicted flow tensor; vector norms are taken over ``dim=1``.
        flow_gt: ground-truth flow, subtracted element-wise from ``flow_pred``.
        valid_flow_mask: optional index (e.g. a boolean mask) used to select
            the pixels that are evaluated. ``None`` keeps every pixel.

    Returns:
        A ``(metrics, count)`` tuple. ``metrics`` is a dict of Python floats:
        ``"epe"`` is the mean end-point error, ``"1px"``/``"3px"``/``"5px"``
        are the fractions of pixels whose error is below that threshold, and
        ``"f1"`` is the percentage of pixels whose error exceeds both 3 and
        5% of the ground-truth magnitude. ``count`` is the number of
        evaluated pixels.
    """
    endpoint_err = (flow_pred - flow_gt).pow(2).sum(dim=1).sqrt()
    gt_magnitude = flow_gt.pow(2).sum(dim=1).sqrt()

    if valid_flow_mask is not None:
        endpoint_err = endpoint_err[valid_flow_mask]
        gt_magnitude = gt_magnitude[valid_flow_mask]

    def fraction(flags):
        # Share of True entries, as a Python float.
        return flags.float().mean().item()

    # A zero ground-truth magnitude gives inf/nan in the ratio; the
    # comparison below then simply evaluates element-wise on those values.
    is_outlier = (endpoint_err > 3) & (endpoint_err / gt_magnitude > 0.05)

    metrics = {"epe": endpoint_err.mean().item()}
    for key, threshold in (("1px", 1), ("3px", 3), ("5px", 5)):
        metrics[key] = fraction(endpoint_err < threshold)
    metrics["f1"] = fraction(is_outlier) * 100

    return metrics, endpoint_err.numel()

In [4]:
import torchvision.transforms.functional as F
class SameRandomCrop(torchvision.transforms.RandomCrop):
    """Random crop that applies one shared window to two images and a flow field."""

    def __init__(self, size):
        super().__init__(size)

    def forward(self, img1, img2, flow):
        """Crop all three inputs with the same randomly drawn window.

        The window is sampled once, from ``img1`` and ``self.size``, and then
        reused for ``img2`` and ``flow`` so the three stay aligned.

        Returns:
            The cropped ``(img1, img2, flow)`` tuple.
        """
        top, left, height, width = self.get_params(img1, self.size)
        window = (top, left, height, width)

        return tuple(F.crop(item, *window) for item in (img1, img2, flow))

In [5]:
class TSintel(torchvision.datasets.Sintel):
    """Sintel dataset wrapper returning cropped tensors plus a validity mask."""

    def __init__(self, root):
        super().__init__(root=root)
        # Every sample is cropped to a fixed 368 x 768 window.
        self.crop = SameRandomCrop((368, 768))
        self.toTensor = torchvision.transforms.ToTensor()

    def __getitem__(self, index):
        """Return ``(img1, img2, flow, valid_flow_mask)`` for sample ``index``.

        Both frames are converted with ``ToTensor``, the flow is wrapped with
        ``torch.from_numpy``, and all three are cropped with one shared random
        window. The mask is all ``True`` and has the shape of the cropped
        flow without its first dimension.
        """
        frame_a, frame_b, flow_np = super().__getitem__(index)

        frame_a = self.toTensor(frame_a)
        frame_b = self.toTensor(frame_b)
        flow_tensor = torch.from_numpy(flow_np)

        img1, img2, flow = self.crop(frame_a, frame_b, flow_tensor)

        # No per-pixel validity information is read here: mark every pixel valid.
        valid_flow_mask = torch.ones(flow.shape[1:], dtype=torch.bool)
        return img1, img2, flow, valid_flow_mask

In [6]:
# Build the dataset from the current working directory.
data = TSintel(root=".")

In [7]:
# Inspect the validity mask (fourth element) of the first sample.
data[0][3]

tensor([[True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        ...,
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True]])

In [8]:
# 80/20 train/test split. The test size is taken as the remainder so the two
# lengths always add up to len(data), which random_split requires; two
# independent roundings do not guarantee that if the ratio is ever changed.
train_size = round(len(data) * 0.8)
test_size = len(data) - train_size

In [9]:
# Fixed seed so the train/test partition is the same on every run.
train_data, test_data = torch.utils.data.random_split(
    dataset=data,
    lengths=[train_size, test_size],
    generator=torch.Generator().manual_seed(42),
)

In [10]:
# Batches of 4; only the training split is shuffled.
train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=4, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_data, batch_size=4, shuffle=False)

In [11]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [12]:
# Instantiate torchvision's small RAFT model; no arguments are passed to the constructor.
model = torchvision.models.optical_flow.raft_small()
# Move the model to `device`. As the last expression of the cell, the returned
# module is also displayed, which is where the architecture printout comes from.
model.to(device)

RAFT(
  (feature_encoder): FeatureEncoder(
    (convnormrelu): ConvNormActivation(
      (0): Conv2d(3, 32, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
      (1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      (2): ReLU(inplace=True)
    )
    (layer1): Sequential(
      (0): BottleneckBlock(
        (convnormrelu1): ConvNormActivation(
          (0): Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1))
          (1): InstanceNorm2d(8, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (2): ReLU(inplace=True)
        )
        (convnormrelu2): ConvNormActivation(
          (0): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): InstanceNorm2d(8, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (2): ReLU(inplace=True)
        )
        (convnormrelu3): ConvNormActivation(
          (0): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1))
          (1): InstanceNorm

In [13]:
# Training hyper-parameters used by the optimizer, scheduler and training loop.
lr = 2e-5  # AdamW learning rate; also used as the OneCycleLR max_lr
epochs = 50  # number of passes over train_loader
weight_decay = 5e-5  # passed to AdamW
eps = 1e-8  # passed to AdamW
num_train_flow_updates = 12  # forwarded to the model as num_flow_updates during training

In [14]:
# AdamW over all model parameters.
optimizer = torch.optim.AdamW(
    params=model.parameters(),
    lr=lr,
    eps=eps,
    weight_decay=weight_decay,
)

# One-cycle schedule spanning epochs * len(train_loader) optimizer steps:
# linear ramp over the first 5% of the steps, then linear annealing.
# Momentum cycling is disabled.
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer=optimizer,
    max_lr=lr,
    epochs=epochs,
    steps_per_epoch=len(train_loader),
    pct_start=0.05,
    anneal_strategy="linear",
    cycle_momentum=False,
)

In [16]:
# Training loop: one optimizer step and one scheduler step per batch. The
# scheduler was configured for epochs * len(train_loader) steps, which is
# exactly the number of iterations performed here.
for epoch in range(epochs):
    # Set training mode explicitly so the loop does not depend on whatever
    # mode an earlier cell (e.g. an evaluation pass) left the model in.
    model.train()

    for i, data_blob in enumerate(train_loader):
        optimizer.zero_grad()

        # Each batch is (img1, img2, flow, valid_flow_mask); move all to `device`.
        image1, image2, flow_gt, valid_flow_mask = (x.to(device) for x in data_blob)
        flow_predictions = model(image1, image2, num_flow_updates=num_train_flow_updates)

        loss = sequence_loss(flow_predictions, flow_gt, valid_flow_mask)

        # Metrics are for monitoring only, so they are computed on the final
        # prediction without autograd tracking. The second return value is
        # the number of evaluated pixels, despite being bound to `epe`.
        # NOTE(review): the metrics are not logged anywhere in this cell.
        with torch.no_grad():
            metrics, epe = compute_metrics(flow_predictions[-1], flow_gt, valid_flow_mask)
        metrics.pop("f1")

        loss.backward()

        # Clip the global gradient norm to 1 before the parameter update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)

        optimizer.step()
        scheduler.step()

0 0
0 1
0 2
0 3
0 4
0 5
0 6
0 7
0 8
0 9
0 10
0 11
0 12
0 13
0 14
0 15
0 16
0 17
0 18
0 19
0 20
0 21
0 22
0 23
0 24
0 25
0 26
0 27
0 28
0 29
0 30
0 31
0 32
0 33
0 34
0 35
0 36
0 37
0 38
0 39
0 40
0 41
0 42
0 43
0 44
0 45
0 46
0 47
0 48
0 49
0 50
0 51
0 52
0 53
0 54
0 55
0 56
0 57
0 58
0 59
0 60
0 61
0 62
0 63
0 64
0 65
0 66
0 67
0 68
0 69
0 70
0 71
0 72
0 73
0 74
0 75
0 76
0 77
0 78
0 79
0 80
0 81
0 82
0 83
0 84
0 85
0 86
0 87
0 88
0 89
0 90
0 91
0 92
0 93
0 94
0 95
0 96
0 97
0 98
0 99
0 100
0 101
0 102
0 103
0 104
0 105
0 106
0 107
0 108
0 109
0 110
0 111
0 112
0 113
0 114
0 115
0 116
0 117
0 118
0 119
0 120
0 121
0 122
0 123
0 124
0 125
0 126
0 127
0 128
0 129
0 130
0 131
0 132
0 133
0 134
0 135
0 136
0 137
0 138
0 139
0 140
0 141
0 142
0 143
0 144
0 145
0 146
0 147
0 148
0 149
0 150
0 151
0 152
0 153
0 154
0 155
0 156
0 157
0 158
0 159
0 160
0 161
0 162
0 163
0 164
0 165
0 166
0 167
0 168
0 169
0 170
0 171
0 172
0 173
0 174
0 175
0 176
0 177
0 178
0 179
0 180
0 181
0 182
0 183
0 184


KeyboardInterrupt: 