In [44]:
import numpy as np
import cv2
import torch


n = 7
data_id = 1
n_time = 9
frame_max = 150
n_sample = frame_max - n_time

led_data = "../data/dataset/led/led_{}.csv".format(data_id)
lcd_data = "../data/dataset/lcd/lcd_{}.mp4".format(data_id)
spec_data = "../data/dataset/spc/spec_{}.csv".format(data_id)
led = np.loadtxt(led_data, dtype=np.int32)
spec = np.loadtxt(spec_data, dtype=np.float32)
lcd = cv2.VideoCapture(lcd_data)
lcd_shape = (int(lcd.get(cv2.CAP_PROP_FRAME_WIDTH)), int(lcd.get(cv2.CAP_PROP_FRAME_HEIGHT)))
lcd_data = torch.zeros((n, frame_max, lcd_shape[1], lcd_shape[0], 3), dtype=torch.uint8)
spec_data = torch.zeros((n, frame_max, spec.shape[0]))
led_data = torch.zeros((n, frame_max, led.shape[0]))
lcd_dataset = torch.zeros((n_sample * n, n_time, lcd_shape[1], lcd_shape[0], 3))
spec_dataset = torch.zeros((n_sample * n, n_time, spec.shape[0]))
led_dataset = torch.zeros((n_sample * n, n_time - 1, led.shape[0]))
led_correct = torch.zeros(n_sample * n, led.shape[0])

print(lcd_dataset.shape)
print(spec_dataset.shape)
print(led_dataset.shape)
print(led_correct.shape)

for data_id in range(1, n + 1):
    led_file = "../data/dataset/led/led_{}.csv".format(data_id)
    lcd_file = "../data/dataset/lcd/lcd_{}.mp4".format(data_id)
    spec_file = "../data/dataset/spc/spec_{}.csv".format(data_id)
    led = np.loadtxt(led_file, dtype=np.int32)
    spec = np.loadtxt(spec_file, dtype=np.float32)
    lcd = cv2.VideoCapture(lcd_file)
    frame_len = int(lcd.get(cv2.CAP_PROP_FRAME_COUNT))
    lcd_shape = (int(lcd.get(cv2.CAP_PROP_FRAME_WIDTH)), int(lcd.get(cv2.CAP_PROP_FRAME_HEIGHT)))
    lcd_vec = np.zeros((frame_len, lcd_shape[1], lcd_shape[0], 3), dtype=np.uint8)
    for frame_idx in range(frame_len):
        lcd.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
        e, frame = lcd.read()
        if not e:
            print("err")
        lcd_vec[frame_idx, :, :, :] = frame

    lcd_data[data_id - 1, 0:frame_len, :, :, :] = torch.tensor(lcd_vec)
    spec_data[data_id - 1, 0:frame_len:, :] = torch.tensor(spec.T)
    led_data[data_id - 1, 0:frame_len, :] = torch.tensor(led.T)

    for data_id in range(n):
        for sample_id in range(n_sample):
            outer_idx = n_sample * data_id + sample_id
            lcd_dataset[outer_idx] = lcd_data[data_id, sample_id:sample_id + n_time, :, :, :]
            led_dataset[outer_idx] = led_data[data_id, sample_id:sample_id + n_time - 1, :]
            spec_dataset[outer_idx] = spec_data[data_id, sample_id:sample_id + n_time, :]
            led_correct[outer_idx] = led_data[data_id, sample_id + n_time, :]
    

torch.Size([987, 9, 320, 436, 3])
torch.Size([987, 9, 128])
torch.Size([987, 8, 267])
torch.Size([987, 267])


In [53]:
import torch.nn as nn
c1 = nn.Conv1d(n_time - 1, 1, 3, 2)
c11 = nn.Conv1d(n_time, 1, 3, 1)
c2 = nn.Conv3d(n_time, 1, (3, 3, 1), (2, 2, 1)) 
c3 = nn.Conv3d(1, 1, (3, 3, 1), (2, 2, 1))
c4 = nn.Conv3d(1, 1, (3, 3, 1), (2, 2, 1))
y0 = c1(led_dataset[3:4])
print(led_dataset[3:4].shape)
y1 = c11(spec_dataset[3:4])
y2 = c2(lcd_dataset[10:11])
y3 = c3(y2)
y4 = c4(y3)
y5 = c4(y4)
y6 = c4(y5)
y7 = torch.reshape(y0, y0.shape[:1] + (-1,))
y8 = torch.reshape(y1, y1.shape[:1] + (-1,))
y9 = torch.reshape(y6, y6.shape[:1] + (-1,))
y10 = torch.cat([y7, y8, y9], 1)
print(y10.shape)
lstm = nn.LSTM(input_size=583, hidden_size=64, batch_first=True)
y11, h = lstm(y10)
y12 = nn.Linear(64, 133)(y11)
print(y12.shape)

torch.Size([1, 8, 267])
torch.Size([1, 583])
torch.Size([1, 133])


In [46]:
from torch.utils.data import TensorDataset, DataLoader
dataset = TensorDataset(lcd_dataset, spec_dataset, led_dataset, led_correct)
loader = DataLoader(dataset, batch_size=16, shuffle=True)

In [56]:
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.spec_conv = nn.Conv1d(
            n_time, n_time, 5
        )
        self.c1_1 = nn.Conv1d(n_time - 1, 1, 3, 2)
        self.c1_2 = nn.Conv1d(n_time, 1, 3, 1)
        self.c2 = nn.Conv3d(n_time, 4, (3, 3, 1), (2, 2, 1)) 
        self.c3 = nn.Conv3d(4, 1, (3, 3, 1), (2, 2, 1))
        self.c4 = nn.Conv3d(1, 1, (3, 3, 1), (2, 2, 1))
        self.lstm = nn.LSTM(input_size=583, hidden_size=64, batch_first=True)
        self.lin = nn.Linear(64, 267)
    
    def forward(self, lcd, spec, led):
        cled = self.c1_1(led)
        cspec = self.c1_2(spec)
        clcd1 = self.c2(lcd)
        clcd2 = self.c3(clcd1)
        clcd3 = self.c4(clcd2)
        clcd4 = self.c4(clcd3)
        clcd5 = self.c4(clcd4)
        ylcd = torch.reshape(clcd5, clcd5.shape[:1] + (-1,))
        yspec = torch.reshape(cspec, cspec.shape[:1] + (-1,))
        yled = torch.reshape(cled, cled.shape[:1] + (-1,))
        y0 = torch.cat([ylcd, yspec, yled], 1)
        y_rnn, h = self.lstm(y0, None)
        y = self.lin(y_rnn)
        return y

net = Net()
net.cuda()
print(net)

Net(
  (spec_conv): Conv1d(9, 9, kernel_size=(5,), stride=(1,))
  (c1_1): Conv1d(8, 1, kernel_size=(3,), stride=(2,))
  (c1_2): Conv1d(9, 1, kernel_size=(3,), stride=(1,))
  (c2): Conv3d(9, 4, kernel_size=(3, 3, 1), stride=(2, 2, 1))
  (c3): Conv3d(4, 1, kernel_size=(3, 3, 1), stride=(2, 2, 1))
  (c4): Conv3d(1, 1, kernel_size=(3, 3, 1), stride=(2, 2, 1))
  (lstm): LSTM(583, 64, batch_first=True)
  (lin): Linear(in_features=64, out_features=267, bias=True)
)


In [57]:
from torch import optim
import matplotlib.pyplot as plt

loss_fnc = nn.MSELoss()
optimizer = optim.SGD(net.parameters(), lr=0.01)
record_loss_train = []
epochs = 1000
for i in range(epochs):
    net.train()
    loss_train = 0
    for j, (lcd, spec, led, t) in enumerate(loader):
        lcd, spec, led, t = lcd.cuda(), spec.cuda(), led.cuda(), t.cuda()
        y = net(lcd, spec, led)
        loss = loss_fnc(y, t)
        loss_train += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    loss_train /= j + 1
    record_loss_train.append(loss_train)
    print("epoch:\t{}\tloss_train={}".format(i, loss))

    # if i % 10 == 0 or i == epochs - 1:
    #     net.eval()
    #     print("Epoch: ", i, "Loss_Train: ", loss_train)
    #     predicted = list(input_data[0].view(-1))
    #     for i in range(n_sample):
    #         x = torch.tensor(predicted[-n_time:])
    #         x = x.view(1, n_time, 1)
    #         y = net(x)
    #         predicted.append(y[0].item())
    #     plt.plot(range(len(sin_y)), sin_y, label="Correct")
    #     plt.plot(range(len(predicted)), predicted, label="Predicted")
    #     plt.legend()
    #     plt.show()


epoch:	0	loss_train=40554.9453125
epoch:	1	loss_train=43018.25390625
epoch:	2	loss_train=37116.02734375
epoch:	3	loss_train=24659.580078125
epoch:	4	loss_train=37295.83984375
epoch:	5	loss_train=34040.8046875
epoch:	6	loss_train=18695.5859375
epoch:	7	loss_train=20375.572265625
epoch:	8	loss_train=29154.958984375
epoch:	9	loss_train=21997.498046875
epoch:	10	loss_train=25202.30078125
epoch:	11	loss_train=27062.83984375
epoch:	12	loss_train=26526.748046875
epoch:	13	loss_train=25273.1640625
epoch:	14	loss_train=21509.66015625
epoch:	15	loss_train=14719.8603515625
epoch:	16	loss_train=16456.759765625
epoch:	17	loss_train=14949.7802734375
epoch:	18	loss_train=10336.99609375
epoch:	19	loss_train=9161.630859375
epoch:	20	loss_train=14279.150390625
epoch:	21	loss_train=12044.2509765625
epoch:	22	loss_train=9546.791015625
epoch:	23	loss_train=9208.7197265625
epoch:	24	loss_train=17838.671875
epoch:	25	loss_train=10039.7060546875
epoch:	26	loss_train=11950.72265625
epoch:	27	loss_train=9041.56

KeyboardInterrupt: 