<a href="https://colab.research.google.com/github/arifinnasif/Natural-Hazard-Prediction/blob/master/lightnet_on_torch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [20]:
import torch.nn as nn
import torch
import random
from tqdm import tqdm

## ConvLSTM

In [21]:
class ConvLSTMCell(nn.Module):

    def __init__(self, input_dim, hidden_dim, kernel_size, bias):
        """
        Initialize ConvLSTM cell.

        Parameters
        ----------
        input_dim: int
            Number of channels of input tensor.
        hidden_dim: int
            Number of channels of hidden state.
        kernel_size: (int, int)
            Size of the convolutional kernel.
        bias: bool
            Whether or not to add the bias.
        """

        super(ConvLSTMCell, self).__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        self.kernel_size = kernel_size
        self.padding = kernel_size[0] // 2, kernel_size[1] // 2
        self.bias = bias

        self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim,
                              out_channels=4 * self.hidden_dim,
                              kernel_size=self.kernel_size,
                              padding=self.padding,
                              bias=self.bias)

    def forward(self, input_tensor, cur_state):
        h_cur, c_cur = cur_state

        combined = torch.cat([input_tensor, h_cur], dim=1)  # concatenate along channel axis
        # print("size of h_cur", h_cur.size())
        # print("size of combined", combined.size())

        combined_conv = self.conv(combined)
        cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)
        i = torch.sigmoid(cc_i)
        f = torch.sigmoid(cc_f)
        o = torch.sigmoid(cc_o)
        g = torch.tanh(cc_g)

        c_next = f * c_cur + i * g
        h_next = o * torch.tanh(c_next)

        return h_next, c_next

    def init_hidden(self, batch_size, image_size):
        height, width = image_size
        return (torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device),
                torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device))


class ConvLSTM(nn.Module):

    """

    Parameters:
        input_dim: Number of channels in input
        hidden_dim: Number of hidden channels
        kernel_size: Size of kernel in convolutions
        num_layers: Number of LSTM layers stacked on each other
        batch_first: Whether or not dimension 0 is the batch or not
        bias: Bias or no bias in Convolution
        return_all_layers: Return the list of computations for all layers
        Note: Will do same padding.

    Input:
        A tensor of size B, T, C, H, W or T, B, C, H, W
    Output:
        A tuple of two lists of length num_layers (or length 1 if return_all_layers is False).
            0 - layer_output_list is the list of lists of length T of each output
            1 - last_state_list is the list of last states
                    each element of the list is a tuple (h, c) for hidden state and memory
    Example:
        >> x = torch.rand((32, 10, 64, 128, 128))
        >> convlstm = ConvLSTM(64, 16, 3, 1, True, True, False)
        >> _, last_states = convlstm(x)
        >> h = last_states[0][0]  # 0 for layer index, 0 for h index
    """

    def __init__(self, input_dim, hidden_dim, kernel_size, num_layers,
                 batch_first=False, bias=True, return_all_layers=False):
        super(ConvLSTM, self).__init__()

        self._check_kernel_size_consistency(kernel_size)

        # Make sure that both `kernel_size` and `hidden_dim` are lists having len == num_layers
        kernel_size = self._extend_for_multilayer(kernel_size, num_layers)
        hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers)
        if not len(kernel_size) == len(hidden_dim) == num_layers:
            raise ValueError('Inconsistent list length.')

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.num_layers = num_layers
        self.batch_first = batch_first
        self.bias = bias
        self.return_all_layers = return_all_layers

        cell_list = []
        for i in range(0, self.num_layers):
            cur_input_dim = self.input_dim if i == 0 else self.hidden_dim[i - 1]

            cell_list.append(ConvLSTMCell(input_dim=cur_input_dim,
                                          hidden_dim=self.hidden_dim[i],
                                          kernel_size=self.kernel_size[i],
                                          bias=self.bias))

        self.cell_list = nn.ModuleList(cell_list)

    def forward(self, input_tensor, hidden_state=None):
        """

        Parameters
        ----------
        input_tensor: todo
            5-D Tensor either of shape (t, b, c, h, w) or (b, t, c, h, w)
        hidden_state: todo
            None. todo implement stateful

        Returns
        -------
        last_state_list, layer_output
        """
        if not self.batch_first:
            # (t, b, c, h, w) -> (b, t, c, h, w)
            input_tensor = input_tensor.permute(1, 0, 2, 3, 4)

        b, _, _, h, w = input_tensor.size()

        # Implement stateful ConvLSTM
        if hidden_state is not None:
            raise NotImplementedError()
        else:
            # Since the init is done in forward. Can send image size here
            hidden_state = self._init_hidden(batch_size=b,
                                             image_size=(h, w))

        layer_output_list = []
        last_state_list = []

        seq_len = input_tensor.size(1)
        cur_layer_input = input_tensor

        for layer_idx in range(self.num_layers):

            h, c = hidden_state[layer_idx]
            output_inner = []
            for t in range(seq_len):
                h, c = self.cell_list[layer_idx](input_tensor=cur_layer_input[:, t, :, :, :],
                                                 cur_state=[h, c])
                output_inner.append(h)

            layer_output = torch.stack(output_inner, dim=1)
            cur_layer_input = layer_output

            layer_output_list.append(layer_output)
            last_state_list.append([h, c])

        if not self.return_all_layers:
            layer_output_list = layer_output_list[-1:]
            last_state_list = last_state_list[-1:]

        return layer_output_list, last_state_list

    def _init_hidden(self, batch_size, image_size):
        init_states = []
        for i in range(self.num_layers):
            init_states.append(self.cell_list[i].init_hidden(batch_size, image_size))
        return init_states

    @staticmethod
    def _check_kernel_size_consistency(kernel_size):
        if not (isinstance(kernel_size, tuple) or
                (isinstance(kernel_size, list) and all([isinstance(elem, tuple) for elem in kernel_size]))):
            raise ValueError('`kernel_size` must be tuple or list of tuples')

    @staticmethod
    def _extend_for_multilayer(param, num_layers):
        if not isinstance(param, list):
            param = [param] * num_layers
        return param

In [22]:
conv_lstm = ConvLSTM(1,[8],[(5,5)],1, batch_first=True) # 1 channel, 8 hidden channel (from geng), 5x5 kernel, 1 layer(?), batch first -> batch er size input tensor e first e dicci,
x = torch.rand((32, 6, 1, 25, 25)) # 32 samples in a batch (batch first), prev 6 hours, 1 channel, 25x25 grid
# convlstm = ConvLSTM(64, 16, 3, 1, True, True, False)
_, last_states = conv_lstm(x)
h,c = last_states[0]  # 0 for layer index, 0 for h index
# print(last_states[0][1].size())
# print("-----")
# print(h.size())


In [23]:
x_ = torch.rand((32, 6, 50, 50))
conv_ = nn.Conv2d(in_channels = 6, out_channels = 6*4, kernel_size = 5, groups = 6, stride = 2, padding = 2)
# conv_.weight.data = [None, None, ...].repeat(64, 1, 1, 1)
# print(conv_.weight.data.size())
conv_(x_).size()

torch.Size([32, 24, 25, 25])

In [24]:
class Encoder_old(nn.Module):
  def __init__(self):
    super(Encoder_old, self).__init__()
    self.prev_hours = 6
    self.conv_2 = nn.Conv2d(in_channels = self.prev_hours,
                          out_channels = self.prev_hours*4, # each input frame gets mapped to 4 layer
                          groups = self.prev_hours,
                          kernel_size = 7,
                          stride = 2,
                          padding = 3)
    self.conv_lstm = ConvLSTM(input_dim = 4,
                               hidden_dim = [8],
                               kernel_size = [(5,5)],
                               num_layers = 1,
                               batch_first=True)
    # print(self.conv_2.weight.data.size())

  def forward(self, input_tensor):
    x = self.conv_2(input_tensor.flatten(1,2))
    _, last_states = self.conv_lstm(torch.unflatten(x, dim = 1, sizes = (6, 4)))
    h,c = last_states[0]

    return h,c





In [25]:
class Encoder(nn.Module):
  def __init__(self, in_channels_for_a_given_time, prev_hours):
    super(Encoder, self).__init__()
    self.prev_hours = prev_hours
    self.hidden_dim = 8
    self.stride = 2
    self.in_channels_for_a_given_time = in_channels_for_a_given_time
    self.conv_2 = nn.Conv2d(in_channels = self.in_channels_for_a_given_time, # one 2d lightnig grid at time t
                          out_channels = 4, # from geng
                          kernel_size = 7,
                          stride = self.stride,
                          padding = 3)
    self.conv_lstm_cell = ConvLSTMCell(input_dim = 4,
                               hidden_dim = self.hidden_dim,
                               kernel_size = (5,5),
                               bias=True)
    # print(self.conv_2.weight.data.size())

  def forward(self, input_tensor):
    # x = self.conv_2(input_tensor.flatten(1,2))
    # _, last_states = self.conv_lstm(torch.unflatten(x, dim = 1, sizes = (6, 4)))
    # h,c = last_states[0]
    b, prev_hours, channels, height, width = input_tensor.size()
    h,c = self.init_hidden(batch_size=b, image_size=(height // self.stride, width // self.stride))

    for t in range(prev_hours):
      x = self.conv_2(input_tensor[:, t, :, :, :])
      h, c = self.conv_lstm_cell(x, cur_state=[h, c])
      return h,c

    #####

    #     layer_output_list = []
    #     last_state_list = []

    #     seq_len = input_tensor.size(1)
    #     cur_layer_input = input_tensor

    #     for layer_idx in range(self.num_layers):

    #         h, c = hidden_state[layer_idx]
    #         output_inner = []
    #         for t in range(seq_len):
    #             h, c = self.cell_list[layer_idx](input_tensor=cur_layer_input[:, t, :, :, :],
    #                                              cur_state=[h, c])
    #             output_inner.append(h)

    #         layer_output = torch.stack(output_inner, dim=1)
    #         cur_layer_input = layer_output

    #         layer_output_list.append(layer_output)
    #         last_state_list.append([h, c])

    #     if not self.return_all_layers:
    #         layer_output_list = layer_output_list[-1:]
    #         last_state_list = last_state_list[-1:]

    #     return layer_output_list, last_state_list

    # return h,c

  def init_hidden(self, batch_size, image_size):
      height, width = image_size
      return (torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv_2.weight.device),
                torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv_2.weight.device))





In [26]:
class Fusion(nn.Module):
  def __init__(self, total_in_channels):
    super(Fusion,self).__init__()
    # h_stacked = torch.cat(h_list, dim = 1)
    # sc_stacked = torch.cat(c_list, dim = 1)
    self.conv_3 = nn.Conv2d(in_channels = total_in_channels,
                            out_channels = 64, # from the paper
                            kernel_size=1)
    self.conv_4 = nn.Conv2d(in_channels = total_in_channels,
                            out_channels = 64, # from the paper
                            kernel_size=1)

  def forward(self, h_list, c_list):
    h_stacked = torch.cat(h_list, dim = 1)
    c_stacked = torch.cat(c_list, dim = 1)


    h_fused_conved = self.conv_3(h_stacked)
    c_fused_conved = self.conv_4(c_stacked)

    return h_fused_conved, c_fused_conved


In [27]:
class Decoder(nn.Module):
  def __init__(self, in_channels_for_a_given_time, next_hours):
    super(Decoder, self).__init__()
    self.next_hours = next_hours
    self.in_channels_for_a_given_time = in_channels_for_a_given_time
    self.conv_5 = nn.Conv2d(in_channels = self.in_channels_for_a_given_time, # single L_{-1} frame
                          out_channels = 4, # from geng
                          kernel_size = 7,
                          stride = 2,
                          padding = 3)
    self.conv_lstm_cell = ConvLSTMCell(input_dim=4, # output of conv
                                       hidden_dim=64, # from geng
                                       kernel_size=(5,5), # from geng
                                       bias=True)
    # print(self.conv_2.weight.data.size())
    self.deconv = nn.ConvTranspose2d(in_channels=64, # as previous conv_lstm_cell had 64 hidden dim so the size is [:,64,:,:]
                                     out_channels=64, # from geng
                                     kernel_size = 7, # from geng
                                     stride = 2, # from geng
                                     padding = 3,
                                     output_padding = 1) # not sure

    self.conv_6 = nn.Conv2d(in_channels = 64, # previous deconv
                          out_channels = 1, # as it is just one layer
                          kernel_size = 1, # from geng
                          stride = 1) # from geng

  def forward(self, input_tensor_L_neg_1, h, c):
    # x = self.conv_2(input_tensor.flatten(1,2))
    # _, last_states = self.conv_lstm(torch.unflatten(x, dim = 1, sizes = (6, 4)))
    # h,c = last_states[0]
    b, _, height, width = input_tensor_L_neg_1.size()

    pred_output = []

    for i in range(self.next_hours):
      # print("---")
      x = self.conv_5(input_tensor_L_neg_1)
      # print(x.size())
      # print("dbg", x.size())
      h, c = self.conv_lstm_cell(x, cur_state=[h, c])
      # print(h.size())
      x = self.deconv(h)
      # print(x.size())
      x = self.conv_6(x)
      # print(x.size())

      pred_output.append(x)
      # print("---")

    pred_output = torch.stack(pred_output, 0).permute(1, 0, 2, 3, 4)
    # print(torch.stack(pred_output, 0).size())





    return pred_output



In [28]:
class LightNet(nn.Module):
  def __init__(self):
    super(LightNet, self).__init__()
    self.obs_enc = Encoder(in_channels_for_a_given_time=1,
                           prev_hours=6)
    self.fus = Fusion(8) # total h (or c) channels. as we are going to use just one h from the obs encoder (size of h = [:,8,:,:])
    self.pred_dec = Decoder(in_channels_for_a_given_time=1,
                            next_hours=6)

  def forward(self, input_tensor):
    h,c = self.obs_enc(input_tensor)
    h,c = self.fus([h],[c])
    out = self.pred_dec(input_tensor[:,-1,:,:,:], h, c)

    return out

In [29]:
# model = LightNet()
# x = torch.rand((32, 6, 1, 50, 50)) # [batch_size, prev_hours, input_layer, image_height, image_width]
# model(x)

In [30]:
def MeteorologicalMeasures(output, target):
  output, target = output.to(torch.device("cuda")), target.to(torch.device("cuda"))
  true_pos = 0
  true_neg = 0

  false_pos = 0
  false_neg = 0

  ytrue = target
  ypred = torch.sigmoid(output)
  ypred = torch.round(ypred)
  true_positives = torch.sum(ytrue * ypred)
  possible_positives = torch.sum(ytrue)
  POD = true_positives / (possible_positives + 1e-10)
  predicted_positives = torch.sum(ypred)
  FAR = true_positives / (predicted_positives + 1e-10)
  return POD, FAR



In [34]:
def train(model):
  data = torch.rand((1024,1,50,50), dtype=torch.double)
  data[data>0.5] = 1.0
  data[data<0.5] = 0.0
  # print(data)
  train_pct = 0.7
  prev_hours = 6
  next_hours = 6
  batch_size = 32
  start_index_list = list(range(data.size(0) - prev_hours - next_hours))
  random.shuffle(start_index_list)
  train_start_index_list = start_index_list[:int(len(start_index_list)*train_pct)]
  val_start_index_list = start_index_list[int(len(start_index_list)*train_pct):]
  # print("hello")

  # print(len(train_start_index_list))
  # print(len(val_start_index_list))

  ## train set creation
  train_x_set = []
  train_y_set = []
  for i in range(len(train_start_index_list) // batch_size):
    train_x_batch = []
    train_y_batch = []
    for j in range(batch_size):
      start = train_start_index_list[batch_size * i + j]
      train_x_batch.append(data[start:(start+prev_hours),:,:,:])
      train_y_batch.append(data[(start+prev_hours):(start+prev_hours+next_hours),:,:,:])

    train_x_batch = torch.stack(train_x_batch,0)
    train_y_batch = torch.stack(train_y_batch,0)

    train_x_set.append(train_x_batch)
    train_y_set.append(train_y_batch)

  # train_x_set = torch.stack(train_x_set,0)
  # train_y_set = torch.stack(train_y_set,0)
  # print(train_x_set.size())
  # print(train_y_set.size())

  ## validation set creation
  val_x_set = []
  val_y_set = []
  for i in range(len(val_start_index_list) // batch_size):
    val_x_batch = []
    val_y_batch = []
    for j in range(batch_size):
      start = val_start_index_list[batch_size * i + j]
      val_x_batch.append(data[start:(start+prev_hours),:,:,:])
      val_y_batch.append(data[(start+prev_hours):(start+prev_hours+next_hours),:,:,:])

    val_x_batch = torch.stack(val_x_batch,0)
    val_y_batch = torch.stack(val_y_batch,0)

    val_x_set.append(val_x_batch)
    val_y_set.append(val_y_batch)

  # val_x_set = torch.stack(val_x_set,0)
  # val_y_set = torch.stack(val_y_set,0)
  # print(val_x_set.size())
  # print(val_y_set.size())
  # print(data[start:(start+prev_hours+next_hours),:,:,:].size())

  model.double()
  # print(model(train_x_set[0]).size())

  ce_loss = nn.CrossEntropyLoss()

  optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
  epoch = 50

  for e in range(epoch):
    avg_train_loss = 0
    avg_val_loss = 0
    avg_train_POD = 0
    avg_train_FAR = 0
    avg_train_ETS = 0
    avg_val_POD = 0
    avg_val_FAR = 0
    avg_val_ETS = 0
    print("Epoch [",e+1,"/",epoch,"]", end=" ")
    for i in tqdm(range(len(train_x_set))):
      data, target = train_x_set[i].to(torch.device("cuda")), train_y_set[i].to(torch.device("cuda"))
      optimizer.zero_grad()
      # train
      model.train()
      output = model(data)
      train_loss = ce_loss(output, target)
      train_POD, train_FAR = MeteorologicalMeasures(output, target)
      avg_train_loss = avg_train_loss + train_loss.item()
      avg_train_POD = avg_train_POD + train_POD.item()
      avg_train_FAR = avg_train_FAR + train_FAR.item()
      # avg_train_ETS = avg_train_ETS + train_ETS
      train_loss.backward()
      optimizer.step()

    for i in range(len(val_x_set)):
      data, target = val_x_set[i].to(torch.device("cuda")), val_y_set[i].to(torch.device("cuda"))
      # val
      model.eval()
      output = model(data)
      val_loss = ce_loss(output, target)
      val_POD, val_FAR = MeteorologicalMeasures(output, target)
      avg_val_loss = avg_val_loss + val_loss.item()
      avg_val_POD = avg_val_POD + val_POD.item()
      avg_val_FAR = avg_val_FAR + val_FAR.item()
      # avg_val_ETS = avg_val_ETS + val_ETS
      # train_loss.backward()
      # optimizer.step()
    avg_train_loss = avg_train_loss / len(train_x_set)
    avg_train_POD = avg_train_POD / len(train_x_set)
    avg_train_FAR = avg_train_FAR / len(train_x_set)
    avg_train_ETS = avg_train_ETS / len(train_x_set)

    avg_val_loss = avg_val_loss / len(val_x_set)
    avg_val_POD = avg_val_POD / len(val_x_set)
    avg_val_FAR = avg_val_FAR / len(val_x_set)
    avg_val_ETS = avg_val_ETS / len(val_x_set)
    print("train loss", avg_train_loss,"|", "train POD", avg_train_POD,"|", "train FAR", avg_train_FAR)#,"|", "train ETS", avg_train_ETS)
    print("val loss", avg_val_loss,"|", "val POD", avg_val_POD,"|", "val FAR", avg_val_FAR)#,"|", "val ETS", avg_val_ETS)
    print()
      # if batch_idx % args.log_interval == 0:
      # print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
      #             epoch, i * len(data), len(train_loader.dataset),
      #             100. * batch_idx / len(train_loader), loss.item()))
      #         if args.dry_run:
      #             break
      # for i in range(opt.steps):
      #     print('STEP: ', i)
      #     def closure():
      #         optimizer.zero_grad()
      #         out = seq(input)
      #         loss = criterion(out, target)
      #         print('loss:', loss.item())
      #         loss.backward()
      #         return loss
      #     optimizer.step(closure)


model = LightNet()
model.to(torch.device('cuda'))
train(model)


Epoch [ 1 / 50 ] 

100%|██████████| 22/22 [01:11<00:00,  3.27s/it]


train loss 5.374536001660926 | train POD 0.0 | train FAR 0.0
val loss 5.371370665762395 | val POD 0.0 | val FAR 0.0

Epoch [ 2 / 50 ] 

100%|██████████| 22/22 [01:11<00:00,  3.26s/it]


train loss 5.37452006041462 | train POD 0.0 | train FAR 0.0
val loss 5.3713737655592455 | val POD 0.0 | val FAR 0.0

Epoch [ 3 / 50 ] 

100%|██████████| 22/22 [01:11<00:00,  3.26s/it]


train loss 5.374512228648092 | train POD 0.0 | train FAR 0.0
val loss 5.371371584760227 | val POD 0.0 | val FAR 0.0

Epoch [ 4 / 50 ] 

100%|██████████| 22/22 [01:11<00:00,  3.26s/it]


train loss 5.374503299822614 | train POD 0.0 | train FAR 0.0
val loss 5.37137514726955 | val POD 0.0 | val FAR 0.0

Epoch [ 5 / 50 ] 

100%|██████████| 22/22 [01:11<00:00,  3.26s/it]


train loss 5.374492631377634 | train POD 0.0 | train FAR 0.0
val loss 5.371381357869897 | val POD 0.0 | val FAR 0.0

Epoch [ 6 / 50 ] 

100%|██████████| 22/22 [01:11<00:00,  3.25s/it]


train loss 5.374479314005023 | train POD 0.0 | train FAR 0.0
val loss 5.371392276550842 | val POD 0.0 | val FAR 0.0

Epoch [ 7 / 50 ] 

100%|██████████| 22/22 [01:11<00:00,  3.26s/it]


train loss 5.374462517268706 | train POD 0.0 | train FAR 0.0
val loss 5.371411426895219 | val POD 0.0 | val FAR 0.0

Epoch [ 8 / 50 ] 

100%|██████████| 22/22 [01:11<00:00,  3.25s/it]


train loss 5.374442334923626 | train POD 0.0 | train FAR 0.0
val loss 5.371440229610749 | val POD 0.0 | val FAR 0.0

Epoch [ 9 / 50 ] 

100%|██████████| 22/22 [01:11<00:00,  3.26s/it]


train loss 5.374418836200408 | train POD 5.1170235542311875e-06 | train FAR 0.36764069262215193
val loss 5.371470262555311 | val POD 1.2968763524026676e-05 | val FAR 0.5925925925779981

Epoch [ 10 / 50 ] 

100%|██████████| 22/22 [01:11<00:00,  3.26s/it]


train loss 5.374391735155702 | train POD 0.00010438406854227688 | train FAR 0.5022473322401219
val loss 5.371507754254796 | val POD 0.00012090155815997607 | val FAR 0.4911045798654714

Epoch [ 11 / 50 ] 

100%|██████████| 22/22 [01:11<00:00,  3.26s/it]


train loss 5.374358458048062 | train POD 0.0002840126503567963 | train FAR 0.5107636703285938
val loss 5.37155664594031 | val POD 0.0002927846440477724 | val FAR 0.5004123935804576

Epoch [ 12 / 50 ] 

100%|██████████| 22/22 [01:11<00:00,  3.26s/it]


train loss 5.374324553858657 | train POD 0.00039161740791713874 | train FAR 0.5245263903412635
val loss 5.37159614091391 | val POD 0.00039885713798922006 | val FAR 0.5038083079096903

Epoch [ 13 / 50 ] 

100%|██████████| 22/22 [01:11<00:00,  3.26s/it]


train loss 5.374288003973471 | train POD 0.0005609734866320551 | train FAR 0.521001036234371
val loss 5.371645179296817 | val POD 0.0005304355768493594 | val FAR 0.49696770420713793

Epoch [ 14 / 50 ] 

100%|██████████| 22/22 [01:11<00:00,  3.26s/it]


train loss 5.374243518809343 | train POD 0.0006962780156483779 | train FAR 0.5245907967211347
val loss 5.371700359229387 | val POD 0.0009330369097470832 | val FAR 0.5028958118249651

Epoch [ 15 / 50 ] 

100%|██████████| 22/22 [01:11<00:00,  3.26s/it]


train loss 5.374202673161323 | train POD 0.0008897625451044256 | train FAR 0.5159045065135234
val loss 5.371739952754486 | val POD 0.0008617001280474998 | val FAR 0.49717673905800264

Epoch [ 16 / 50 ] 

100%|██████████| 22/22 [01:11<00:00,  3.26s/it]


train loss 5.374156111693195 | train POD 0.0012102311080128352 | train FAR 0.516107752736495
val loss 5.371776915675202 | val POD 0.0013467122456786316 | val FAR 0.4979339447038043

Epoch [ 17 / 50 ] 

100%|██████████| 22/22 [01:11<00:00,  3.26s/it]


train loss 5.374108801288893 | train POD 0.0016524382446110332 | train FAR 0.5161104370141655
val loss 5.371792266761482 | val POD 0.002326565908936997 | val FAR 0.49803806691513697

Epoch [ 18 / 50 ] 

100%|██████████| 22/22 [01:11<00:00,  3.26s/it]


train loss 5.374055739287235 | train POD 0.0028812191476836366 | train FAR 0.5182550047486602
val loss 5.371799668668263 | val POD 0.0027569648593385467 | val FAR 0.49283188119867205

Epoch [ 19 / 50 ] 

100%|██████████| 22/22 [01:11<00:00,  3.25s/it]


train loss 5.374000627883327 | train POD 0.005414208105363811 | train FAR 0.5213082799281241
val loss 5.371787488875975 | val POD 0.0021968636315607784 | val FAR 0.4983024906212441

Epoch [ 20 / 50 ] 

100%|██████████| 22/22 [01:11<00:00,  3.26s/it]


train loss 5.373945121940713 | train POD 0.008807268263554763 | train FAR 0.5171752727676403
val loss 5.371808534696444 | val POD 0.0030539973442269 | val FAR 0.4947177153595509

Epoch [ 21 / 50 ] 

100%|██████████| 22/22 [01:11<00:00,  3.26s/it]


train loss 5.373890291772216 | train POD 0.008928411530757295 | train FAR 0.5173606406659036
val loss 5.371811031464265 | val POD 0.004783955872123664 | val FAR 0.49917723500305794

Epoch [ 22 / 50 ] 

100%|██████████| 22/22 [01:11<00:00,  3.26s/it]


train loss 5.373835663557547 | train POD 0.01905491562710858 | train FAR 0.5151832182961319
val loss 5.37186822528794 | val POD 0.011513351106236414 | val FAR 0.4971937635527231

Epoch [ 23 / 50 ] 

100%|██████████| 22/22 [01:11<00:00,  3.26s/it]


train loss 5.373776093139429 | train POD 0.0159997028153803 | train FAR 0.5183470986140365
val loss 5.37211925792755 | val POD 0.019562377611643485 | val FAR 0.4956146668688197

Epoch [ 24 / 50 ] 

100%|██████████| 22/22 [01:11<00:00,  3.26s/it]


train loss 5.373697148602933 | train POD 0.01724218866106237 | train FAR 0.5173579423488883
val loss 5.372453395529043 | val POD 0.023256411466723883 | val FAR 0.4953869392603679

Epoch [ 25 / 50 ] 

100%|██████████| 22/22 [01:11<00:00,  3.26s/it]


train loss 5.373617797046905 | train POD 0.01781746773537429 | train FAR 0.5188307208281651
val loss 5.3724627355539285 | val POD 0.018366905966621945 | val FAR 0.497057584364316

Epoch [ 26 / 50 ] 

100%|██████████| 22/22 [01:11<00:00,  3.26s/it]


train loss 5.373508227749141 | train POD 0.019648610996814884 | train FAR 0.5198054652283006
val loss 5.37249170813271 | val POD 0.021959194919188418 | val FAR 0.4976353817406484

Epoch [ 27 / 50 ] 

100%|██████████| 22/22 [01:11<00:00,  3.26s/it]


train loss 5.373397520478611 | train POD 0.025826903105683172 | train FAR 0.5191209254540591
val loss 5.372546402884374 | val POD 0.029129538061908995 | val FAR 0.49737968930930265

Epoch [ 28 / 50 ] 

100%|██████████| 22/22 [01:11<00:00,  3.26s/it]


train loss 5.373277348984575 | train POD 0.031030821360251738 | train FAR 0.5187501567115627
val loss 5.372640332977829 | val POD 0.03581360445259519 | val FAR 0.4967855985928618

Epoch [ 29 / 50 ] 

100%|██████████| 22/22 [01:11<00:00,  3.26s/it]


train loss 5.373146959961765 | train POD 0.035337473605510455 | train FAR 0.5188476373065174
val loss 5.372763203299125 | val POD 0.04096461991338762 | val FAR 0.49739252637934445

Epoch [ 30 / 50 ] 

100%|██████████| 22/22 [01:11<00:00,  3.26s/it]


train loss 5.373002960987025 | train POD 0.040356315209695265 | train FAR 0.5194968388250851
val loss 5.372903627086412 | val POD 0.04515473946238241 | val FAR 0.4973073359337352

Epoch [ 31 / 50 ] 

100%|██████████| 22/22 [01:11<00:00,  3.26s/it]


train loss 5.372845553638367 | train POD 0.04640603051268705 | train FAR 0.5195804256855326
val loss 5.373061413128179 | val POD 0.05061196656927712 | val FAR 0.49674342953180883

Epoch [ 32 / 50 ] 

100%|██████████| 22/22 [01:11<00:00,  3.26s/it]


train loss 5.372675983235054 | train POD 0.05345818256611334 | train FAR 0.5196030037770831
val loss 5.3732405331096595 | val POD 0.05871732754935058 | val FAR 0.49770476803358854

Epoch [ 33 / 50 ] 

100%|██████████| 22/22 [01:11<00:00,  3.26s/it]


train loss 5.37249500197921 | train POD 0.061454549679304346 | train FAR 0.5199198382060394
val loss 5.373444915386475 | val POD 0.07034255051559496 | val FAR 0.4975588915579473

Epoch [ 34 / 50 ] 

100%|██████████| 22/22 [01:11<00:00,  3.26s/it]


train loss 5.372304785792204 | train POD 0.07049458420047613 | train FAR 0.5195629872956867
val loss 5.373659130487561 | val POD 0.08299804809202564 | val FAR 0.49834090214895865

Epoch [ 35 / 50 ] 

100%|██████████| 22/22 [01:11<00:00,  3.26s/it]


train loss 5.372110226183288 | train POD 0.08067585416632012 | train FAR 0.51893566025604
val loss 5.373853640741394 | val POD 0.09100002807335311 | val FAR 0.4985336365850943

Epoch [ 36 / 50 ] 

100%|██████████| 22/22 [01:11<00:00,  3.26s/it]


train loss 5.3719114949727365 | train POD 0.09055686086031552 | train FAR 0.5180710070865593
val loss 5.374026270056113 | val POD 0.09551342519757193 | val FAR 0.49900179097139435

Epoch [ 37 / 50 ] 

100%|██████████| 22/22 [01:11<00:00,  3.26s/it]


train loss 5.371707050317193 | train POD 0.09969031567221878 | train FAR 0.5177044123244253
val loss 5.374189597320848 | val POD 0.09960310346624009 | val FAR 0.49882038717255167

Epoch [ 38 / 50 ] 

100%|██████████| 22/22 [01:11<00:00,  3.26s/it]


train loss 5.371502005640391 | train POD 0.1087043563679028 | train FAR 0.5176175811463934
val loss 5.374309607317129 | val POD 0.10387797212017404 | val FAR 0.49919033367941396

Epoch [ 39 / 50 ] 

100%|██████████| 22/22 [01:11<00:00,  3.26s/it]


train loss 5.371295211395789 | train POD 0.11668137851188008 | train FAR 0.517806272643738
val loss 5.374416394682445 | val POD 0.10837164813314372 | val FAR 0.49926246987531553

Epoch [ 40 / 50 ] 

100%|██████████| 22/22 [01:11<00:00,  3.26s/it]


train loss 5.371090643864185 | train POD 0.12392161286193115 | train FAR 0.5181725695673957
val loss 5.374537746536971 | val POD 0.11675825618904555 | val FAR 0.4993269178008581

Epoch [ 41 / 50 ] 

100%|██████████| 22/22 [01:11<00:00,  3.26s/it]


train loss 5.370867757648241 | train POD 0.12939455727720997 | train FAR 0.5185484150129768
val loss 5.374708181621015 | val POD 0.12255668681536257 | val FAR 0.49937994081416814

Epoch [ 42 / 50 ] 

100%|██████████| 22/22 [01:11<00:00,  3.26s/it]


train loss 5.370634937916654 | train POD 0.13733180593660418 | train FAR 0.5186813847451672
val loss 5.374964849515096 | val POD 0.12201241308706035 | val FAR 0.4994597948973411

Epoch [ 43 / 50 ] 

100%|██████████| 22/22 [01:11<00:00,  3.26s/it]


train loss 5.37043545569251 | train POD 0.1471627077053269 | train FAR 0.518176503574345
val loss 5.375141049437492 | val POD 0.1298282154146131 | val FAR 0.499450426930811

Epoch [ 44 / 50 ] 

100%|██████████| 22/22 [01:11<00:00,  3.26s/it]


train loss 5.3702176450037085 | train POD 0.15580607857371376 | train FAR 0.5179869232989721
val loss 5.375507102213917 | val POD 0.1405697887784064 | val FAR 0.4993113710121128

Epoch [ 45 / 50 ] 

100%|██████████| 22/22 [01:11<00:00,  3.26s/it]


train loss 5.369979699359061 | train POD 0.16128623833748768 | train FAR 0.5179606625980185
val loss 5.3757612324955595 | val POD 0.14728409553119648 | val FAR 0.4996374681854738

Epoch [ 46 / 50 ] 

100%|██████████| 22/22 [01:11<00:00,  3.26s/it]


train loss 5.369736956409983 | train POD 0.1637883168259728 | train FAR 0.518306016278384
val loss 5.375960004089014 | val POD 0.15606579568951695 | val FAR 0.49940407025000666

Epoch [ 47 / 50 ] 

100%|██████████| 22/22 [01:11<00:00,  3.26s/it]


train loss 5.369487988461305 | train POD 0.16851238875752944 | train FAR 0.5186356423347487
val loss 5.376204795956717 | val POD 0.16344419055108778 | val FAR 0.4993701380191976

Epoch [ 48 / 50 ] 

100%|██████████| 22/22 [01:11<00:00,  3.26s/it]


train loss 5.369244667131117 | train POD 0.1740101022423082 | train FAR 0.5187252235623757
val loss 5.376434206493818 | val POD 0.1735603973231121 | val FAR 0.49952075618518854

Epoch [ 49 / 50 ] 

100%|██████████| 22/22 [01:11<00:00,  3.26s/it]


train loss 5.368998569992329 | train POD 0.17878643620238702 | train FAR 0.5189427076869021
val loss 5.3766798279461305 | val POD 0.18352876931009499 | val FAR 0.4994182998088713

Epoch [ 50 / 50 ] 

100%|██████████| 22/22 [01:11<00:00,  3.26s/it]


train loss 5.368753049891968 | train POD 0.1832036260217417 | train FAR 0.5189567236874374
val loss 5.37694528491005 | val POD 0.19254363573932448 | val FAR 0.499462092090679



In [None]:
# x = torch.rand((32, 6, 1, 50, 50))
# enc = Encoder()
# h,c = enc(x)
# # print(h.size(1)+c.size(1))
# fus = Fusion(h.size(1))
# h,c = fus([h], [c])
# dec = Decoder()
# print("size of x[:,-1,:,:,:]", x[:,-1,:,:,:].size())
# print(h.size())
# out = dec(x[:,-1,:,:,:], h,c)
# print(h.size())
# print(c.size())
# print(torch.cat([h,c], dim=1).size())