In [14]:
import torch
import torch.nn as nn
from data_process import get_data_loaders, getFeatures
# from model.TCN import TemporalConvNet, TCN
from tqdm import tqdm

In [3]:
# Build the dataset object and the train/validation/test loaders via the
# project helper in data_process.
# NOTE(review): the array printed under this cell presumably comes from one of
# these two helper calls — confirm which, and silence it if it is debug output.
data, train_loader, val_loader, test_loader = get_data_loaders()
# len(features) is what the training cell uses as the model's input_size.
features = getFeatures()

[0.0041988  0.00420218 0.00420386 ... 0.41108662 0.41109065 0.41109251]


In [67]:
# Peek at the first training batch only. `i`, `x` and `y` stay bound after the
# loop; later cells read `x` and `y`.
for i, (x, y) in enumerate(train_loader):
    print(x.shape, y.shape, sep="\n")
    break

torch.Size([64, 10, 32])
torch.Size([64, 2])


In [75]:
# `y` is the target batch left bound by the first-batch inspection loop;
# show the target shape of a single sample.
print(y[0].shape)

torch.Size([2])


In [70]:
# Smoke-test the encoder on the batch `x` captured by the first-batch
# inspection cell and show the shapes of the two returned state tensors.
# NOTE: `Encoder` is defined in a cell further down, so that cell has to be
# executed before this one.
model = Encoder(32, 10, [64, 64, 64], 32)
h, c = model(x)
print(h.shape, c.shape, sep="\n")

NameError: name 'seq2seq' is not defined

In [76]:
# build TCN-LSTM model

class Encoder(nn.Module):
    """TCN feature extractor that produces the initial (h, c) state for the decoder.

    The input window goes through a ``TemporalConvNet`` and a per-time-step
    linear projection to ``lstm_num_hidden`` features. From that projection,
    ``h`` is the last time step and ``c`` is a learned linear combination of
    all time steps (``fc_time``).

    ``TemporalConvNet`` is not defined in this cell; it must already be in
    scope when the class is instantiated.

    Args:
        input_size: number of features per time step of the input.
        seq_len: number of time steps in the input window (in_features of
            ``fc_time``).
        tcn_num_channels: channel sizes handed to ``TemporalConvNet``; the last
            entry is the in_features of ``fc_feature``.
        lstm_num_hidden: size of each returned state vector.
        tcn_kernel_size: forwarded to ``TemporalConvNet``.
        tcn_dropout: forwarded to ``TemporalConvNet``.
    """

    def __init__(self, input_size, seq_len, tcn_num_channels, lstm_num_hidden, tcn_kernel_size=2, tcn_dropout=0.2):
        super().__init__()
        self.tcn = TemporalConvNet(input_size, tcn_num_channels, tcn_kernel_size, tcn_dropout)
        self.fc_feature = nn.Linear(tcn_num_channels[-1], lstm_num_hidden)
        self.fc_time = nn.Linear(seq_len, 1)

        self.lstm_num_hidden = lstm_num_hidden

    def forward(self, x):
        """Return ``(h, c)``, each of shape (batch_size, lstm_num_hidden).

        ``x`` is (batch_size, seq_len, input_size).
        NOTE(review): assumes the TCN keeps the time dimension at ``seq_len``,
        otherwise ``fc_time`` will not match — confirm against TemporalConvNet.
        """
        # The TCN works channels-first, so swap to (batch, features, time) and back.
        channels_first = x.transpose(1, 2)
        tcn_out = self.tcn(channels_first).transpose(1, 2)  # (batch, time, tcn_num_channels[-1])
        projected = self.fc_feature(tcn_out)                # (batch, time, lstm_num_hidden)

        # Hidden state: projection of the most recent time step.
        last_step = projected[:, -1, :]

        # Cell state: collapse the time axis with a learned weighting.
        time_last = projected.transpose(1, 2)               # (batch, lstm_num_hidden, time)
        pooled = self.fc_time(time_last).squeeze(2)

        return last_step, pooled


class Decoder(nn.Module):
    """LSTM decoder that unrolls ``seq_len`` steps from a given initial state.

    The same input ``x`` is fed to the LSTM at every step (the decoder is not
    autoregressive); only the recurrent state changes between steps. Each
    step's LSTM output is mapped to a single value by ``fc``.

    Args:
        input_size: feature size of ``x`` (the LSTM input size).
        seq_len: number of decoding steps, i.e. the output length.
        hidden_size: LSTM hidden size; also the size of the initial states.
        num_layers: number of stacked LSTM layers.
    """

    def __init__(self, input_size, seq_len, hidden_size, num_layers=1):
        super(Decoder, self).__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, 1)
        self.num_layers = num_layers
        self.seq_len = seq_len

    def _as_layer_state(self, state):
        """Bring an initial state to the (num_layers, batch, hidden) layout nn.LSTM expects."""
        if state.dim() == 2:
            # (batch, hidden) -> (1, batch, hidden). Without this, the default
            # num_layers=1 passes a 2-D state next to a batched 3-D input and
            # nn.LSTM rejects it.
            state = state.unsqueeze(0)
        if state.size(0) == 1 and self.num_layers > 1:
            # Every layer starts from the same state.
            state = state.repeat(self.num_layers, 1, 1)
        # The state may be a strided slice of a larger tensor; hand the LSTM
        # a contiguous tensor.
        return state.contiguous()

    def forward(self, x, hidden, cell):
        """Decode ``seq_len`` steps.

        Args:
            x: (batch_size, steps, input_size) input fed at every decoding step.
            hidden: initial hidden state, (batch_size, hidden_size) or
                (1 | num_layers, batch_size, hidden_size).
            cell: initial cell state, same layout as ``hidden``.

        Returns:
            Tensor of per-step predictions concatenated along dim 1; with a
            one-step ``x`` this is (batch_size, seq_len, 1).
        """
        hidden = self._as_layer_state(hidden)
        cell = self._as_layer_state(cell)

        outputs = []
        for _ in range(self.seq_len):
            output, (hidden, cell) = self.lstm(x, (hidden, cell))
            outputs.append(self.fc(output))

        return torch.cat(outputs, dim=1)


class TCN_LSTM(nn.Module):
    """Sequence-to-sequence model: TCN encoder followed by an LSTM decoder.

    The encoder turns the input window into an initial (h, c) state; the
    decoder then unrolls ``output_len`` steps, each fed with the last observed
    time step of the input window.

    Args:
        input_size: number of features per time step.
        input_len: number of time steps in the input window.
        output_len: number of time steps to predict.
        tcn_num_channels: channel sizes of the encoder TCN.
        lstm_num_hidden: hidden size shared by the encoder output and the
            decoder LSTM.
        tcn_kernel_size: forwarded to the encoder.
        tcn_dropout: forwarded to the encoder.
        num_layers: number of stacked LSTM layers in the decoder.
    """

    def __init__(self, input_size, input_len, output_len, tcn_num_channels, lstm_num_hidden, tcn_kernel_size=2, tcn_dropout=0.2, num_layers=1):
        super(TCN_LSTM, self).__init__()
        self.encoder = Encoder(input_size, input_len, tcn_num_channels, lstm_num_hidden, tcn_kernel_size, tcn_dropout)
        # The decoder input is a raw time step of x (input_size features), so
        # the decoder LSTM's input size must be input_size. Using
        # lstm_num_hidden here only works when the two happen to be equal.
        self.decoder = Decoder(input_size, output_len, lstm_num_hidden, num_layers)

    def forward(self, x):
        """Map x of shape (batch_size, input_len, input_size) to (batch_size, output_len, 1)."""
        h, c = self.encoder(x)
        # Last time step of the window, kept 3-D: (batch_size, 1, input_size).
        xt = x[:, -1, :].unsqueeze(1)
        outputs = self.decoder(xt, h, c)  # (batch_size, output_len, 1)
        return outputs

In [79]:
from tqdm import tqdm

# Model dimensions. input_len / output_len match the batch shapes printed by
# the first-batch inspection cell: inputs (64, 10, 32), targets (64, 2).
input_size = len(features)
input_len = 10
output_len = 2
hidden_size = 32
num_layers = 3

# Standalone encoder/decoder instances; the training loop below only uses
# `seq2seq`, which builds its own encoder and decoder internally.
encoder = Encoder(input_size, input_len, [64,64,64], hidden_size)
decoder = Decoder(hidden_size, output_len, hidden_size, num_layers=num_layers)
seq2seq = TCN_LSTM(input_size, input_len, output_len, [64,64,64], hidden_size, num_layers=num_layers)

# Loss and optimizer
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(seq2seq.parameters(), lr=0.0001)

# Training loop with per-epoch validation. The weights with the lowest
# validation loss are checkpointed; there is no early stopping, all
# num_epochs epochs are always run.
num_epochs = 10
best_epoch = 0
best_val_loss = float('inf')
train_losses, val_losses = [], []

for epoch in range(num_epochs):
    # Training phase
    seq2seq.train()
    total_train_loss = 0
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [TRAIN]")
    for inputs, targets in progress_bar:
        optimizer.zero_grad()
        outputs = seq2seq(inputs)
        outputs = outputs.squeeze(-1) # (batch_size, output_len)

        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        total_train_loss += loss.item()
        progress_bar.set_postfix({'train_loss': loss.item()})

    # Mean of the per-batch losses (not weighted by batch size).
    average_train_loss = total_train_loss / len(train_loader)
    train_losses.append(average_train_loss)

    print(f"Epoch {epoch+1}/{num_epochs}, Average Training Loss: {average_train_loss:.4f}")

    # Validation phase
    seq2seq.eval()
    total_val_loss = 0
    progress_bar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} [VAL]")
    # No gradients are needed for validation; skipping autograd avoids
    # building a graph for every validation batch.
    with torch.no_grad():
        for inputs, targets in progress_bar:
            outputs = seq2seq(inputs)
            outputs = outputs.squeeze(-1) # (batch_size, output_len)

            loss = criterion(outputs, targets)

            total_val_loss += loss.item()
            progress_bar.set_postfix({'val_loss': loss.item()})

    average_val_loss = total_val_loss / len(val_loader)
    val_losses.append(average_val_loss)
    print(f"Epoch {epoch+1}/{num_epochs}, Average Validation Loss: {average_val_loss:.4f}")

    # Save the model with least validation loss
    if average_val_loss < best_val_loss:
        best_epoch = epoch + 1
        best_val_loss = average_val_loss
        torch.save(seq2seq.state_dict(), 'TCN_LSTM_best_model.pt')
        


Epoch 1/10 [TRAIN]:   0%|          | 0/2418 [00:00<?, ?it/s, train_loss=0.12]

Epoch 1/10 [TRAIN]: 100%|██████████| 2418/2418 [00:44<00:00, 54.32it/s, train_loss=0.000365]


Epoch 1/10, Average Training Loss: 0.0059


Epoch 1/10 [VAL]: 100%|██████████| 806/806 [00:05<00:00, 148.85it/s, val_loss=0.000146]


Epoch 1/10, Average Validation Loss: 0.0006


Epoch 2/10 [TRAIN]: 100%|██████████| 2418/2418 [00:46<00:00, 52.32it/s, train_loss=8.86e-5] 


Epoch 2/10, Average Training Loss: 0.0006


Epoch 2/10 [VAL]: 100%|██████████| 806/806 [00:05<00:00, 151.22it/s, val_loss=5.48e-5] 


Epoch 2/10, Average Validation Loss: 0.0004


Epoch 3/10 [TRAIN]: 100%|██████████| 2418/2418 [00:47<00:00, 50.48it/s, train_loss=5.06e-5] 


Epoch 3/10, Average Training Loss: 0.0005


Epoch 3/10 [VAL]: 100%|██████████| 806/806 [00:05<00:00, 148.85it/s, val_loss=4.45e-5]


Epoch 3/10, Average Validation Loss: 0.0004


Epoch 4/10 [TRAIN]: 100%|██████████| 2418/2418 [00:47<00:00, 50.52it/s, train_loss=2.98e-5] 


Epoch 4/10, Average Training Loss: 0.0004


Epoch 4/10 [VAL]: 100%|██████████| 806/806 [00:06<00:00, 126.34it/s, val_loss=4.43e-5]


Epoch 4/10, Average Validation Loss: 0.0004


Epoch 5/10 [TRAIN]: 100%|██████████| 2418/2418 [00:48<00:00, 49.71it/s, train_loss=2.05e-5] 


Epoch 5/10, Average Training Loss: 0.0004


Epoch 5/10 [VAL]: 100%|██████████| 806/806 [00:05<00:00, 140.99it/s, val_loss=3.65e-5]


Epoch 5/10, Average Validation Loss: 0.0004


Epoch 6/10 [TRAIN]: 100%|██████████| 2418/2418 [00:48<00:00, 49.45it/s, train_loss=2.05e-5] 


Epoch 6/10, Average Training Loss: 0.0004


Epoch 6/10 [VAL]: 100%|██████████| 806/806 [00:05<00:00, 142.27it/s, val_loss=3.81e-5]


Epoch 6/10, Average Validation Loss: 0.0004


Epoch 7/10 [TRAIN]: 100%|██████████| 2418/2418 [00:48<00:00, 49.81it/s, train_loss=1.24e-5] 


Epoch 7/10, Average Training Loss: 0.0004


Epoch 7/10 [VAL]: 100%|██████████| 806/806 [00:05<00:00, 139.50it/s, val_loss=1.51e-5]


Epoch 7/10, Average Validation Loss: 0.0004


Epoch 8/10 [TRAIN]: 100%|██████████| 2418/2418 [00:52<00:00, 46.16it/s, train_loss=1.03e-5] 


Epoch 8/10, Average Training Loss: 0.0004


Epoch 8/10 [VAL]: 100%|██████████| 806/806 [00:06<00:00, 121.91it/s, val_loss=1.69e-5]


Epoch 8/10, Average Validation Loss: 0.0004


Epoch 9/10 [TRAIN]: 100%|██████████| 2418/2418 [00:52<00:00, 45.78it/s, train_loss=2.03e-5] 


Epoch 9/10, Average Training Loss: 0.0004


Epoch 9/10 [VAL]: 100%|██████████| 806/806 [00:05<00:00, 136.58it/s, val_loss=3.46e-5]


Epoch 9/10, Average Validation Loss: 0.0004


Epoch 10/10 [TRAIN]: 100%|██████████| 2418/2418 [00:51<00:00, 46.94it/s, train_loss=1.11e-5] 


Epoch 10/10, Average Training Loss: 0.0004


Epoch 10/10 [VAL]: 100%|██████████| 806/806 [00:06<00:00, 132.40it/s, val_loss=3.67e-5]

Epoch 10/10, Average Validation Loss: 0.0004



