Skip to content

Commit

Permalink
lstm model fix
Browse files Browse the repository at this point in the history
  • Loading branch information
ArpanBiswas99 committed Jan 22, 2024
1 parent 0bc342d commit d65b8b3
Showing 1 changed file with 14 additions and 21 deletions.
35 changes: 14 additions & 21 deletions modeling_and_evaluation/SoC_Estimation_LSTM.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -148,28 +148,21 @@
" def get_times(self):\n",
" return self.times\n",
"\n",
"# LSTM Model\n",
"class LSTM(nn.Module):\n",
"# SoCLSTM Model\n",
"class SoCLSTM(nn.Module):\n",
" def __init__(self, input_size, hidden_size, num_layers):\n",
" super(LSTM, self).__init__()\n",
"\n",
" self.conv = nn.Conv1d(in_channels=1, out_channels=6, kernel_size=3, stride=1)\n",
" self.relu1 = nn.ReLU()\n",
" self.rnn = nn.LSTM(6, hidden_size, num_layers)\n",
" self.relu2 = nn.ReLU()\n",
" self.reg_1 = nn.Linear(hidden_size, 1)\n",
" self.reg_2 = nn.Linear(6, 1)\n",
" \n",
" super(SoCLSTM, self).__init__()\n",
" self.hidden_size = hidden_size\n",
" self.num_layers = num_layers\n",
" self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)\n",
" self.fc = nn.Linear(hidden_size, 1)\n",
"\n",
" def forward(self, x):\n",
" x = self.relu1(self.conv(x))\n",
" x, _ = self.rnn(x)\n",
" s, b, h = x.shape\n",
" x = x.view(s*b, h)\n",
" x = self.relu2(self.reg_1(x))\n",
" x = x.view(s, -1)\n",
" x = self.reg_2(x)\n",
" x = x.view(s, 1, 1)\n",
" return x\n",
" h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size, dtype=x.dtype, device=x.device)\n",
" c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size, dtype=x.dtype, device=x.device)\n",
" out, _ = self.lstm(x, (h0, c0))\n",
" out = self.fc(out[:, -1, :])\n",
" return out\n",
"\n",
"# Training loop with validation\n",
"def train_and_validate(model, criterion, optimizer, train_loader, val_loader, epochs, device, patience=20, min_delta=0.001):\n",
Expand Down Expand Up @@ -3039,7 +3032,7 @@
"kernelspec": {
"display_name": "myenv",
"language": "python",
"name": "myenv"
"name": "python3"
},
"language_info": {
"codemirror_mode": {
Expand Down

0 comments on commit d65b8b3

Please sign in to comment.