In [1]:
!pip install -q librosa matplotlib spafe torch pandas

You should consider upgrading via the '/home/ashutosh/Desktop/ugmqa_project/venv/bin/python -m pip install --upgrade pip' command.[0m[33m
[0m

In [2]:
from torch.utils.data import random_split, Dataset, DataLoader

import os

os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

In [3]:
import os
import librosa
import librosa.display
import numpy as np
import IPython.display as ipd
import matplotlib.pyplot as plt
%matplotlib inline

from collections import defaultdict
from spafe.utils import vis
from spafe.features.lfcc import lfcc
import pandas as pd

In [4]:
import torch
import torch.nn as nn
import math

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class SelfAttention(nn.Module):

    def __init__(self, embed_size, num_heads):

        super(SelfAttention, self).__init__()

        self.embed_size = embed_size
        self.num_heads = num_heads
        self.head_dim = embed_size // num_heads

        assert (self.head_dim * num_heads ==
                embed_size), "Embed size needs to be divisible by heads"

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(num_heads * self.head_dim, embed_size)

    def forward(self, value, key, query, mask):

        N = query.shape[0]
        value_len, key_len, query_len = value.shape[1], key.shape[1], query.shape[1]

        # Split embedding into self.num_heads pieces
        value = value.reshape(N, value_len, self.num_heads, self.head_dim)
        key = key.reshape(N, key_len, self.num_heads, self.head_dim)
        query = query.reshape(N, query_len, self.num_heads, self.head_dim)

        values = self.values(value)
        keys = self.keys(key)
        queries = self.queries(query)
        energy = torch.einsum(
            "nqhd,nkhd->nhqk", [queries, keys])  # MatMul Q and K
        # queries shape: (N, query_len, heads, heads_dim)
        # keys shape: (N, query_len, heads, heads_dim)
        # energy shape: (N, heads, query_len, key_len)
        # print("Mask", mask.shape)
        # print("Energy", energy.shape)

        # energy = torch.zeros((N, self.num_heads, query_len, key_len)).to(device)

        # mask = torch.zeros((1, 1, 1, key_len)).to(device)

        if mask is not None:
            # print(mask)
            energy = energy.masked_fill(mask == 0, float("-1e20"))
            # print(energy[0][0][0])

        attention = torch.softmax(energy / (self.embed_size ** 0.5), dim=3)

        out = torch.einsum("nhql,nlhd->nqhd", [attention, values])
        # print("output_out:", out.shape)
        # print('Out shape', out.shape)
        out = out.reshape(
            N, query_len, self.num_heads * self.head_dim
        )

        out = self.fc_out(out)

        return out

In [5]:
class PositionalEncoding(nn.Module):

    def __init__(self, embed_size, dropout, max_len=5000):
        super().__init__()

        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_size, 2)
                             * (-math.log(10000.0) / embed_size))
        self.position_encoding = torch.zeros(max_len, embed_size).to(device)
        self.position_encoding[:, 0::2] = torch.sin(
            position * div_term).to(device)
        self.position_encoding[:, 1::2] = torch.cos(
            position * div_term).to(device)
        self.register_buffer('pe', self.position_encoding)

    def forward(self, x):
        # print('pe_x', x.shape)
        # print('pe', self.position_encoding.shape)
        x = x + self.position_encoding[:x.size(0)]
        return self.dropout(x)

In [6]:
class TransformerBlock(nn.Module):

    def __init__(self, embed_size, heads, dropout, forward_expansion):
        super(TransformerBlock, self).__init__()

        self.attention = SelfAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)

        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion * embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion * embed_size, embed_size)
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, mask):
        attention = self.attention(value, key, query, mask)

        x = self.dropout(self.norm1((attention + query)))

        forward = self.feed_forward(x)
        out = self.dropout(self.norm2(x + forward))

        return out

In [7]:
class Encoder(nn.Module):

    def __init__(
        self,
        src_vocab_size,
        embed_size,
        num_layers,
        heads,
        device,
        forward_expansion,
        dropout,
        max_length
    ):
        super(Encoder, self).__init__()
        self.embed_size = embed_size
        self.device = device
        # self.position_embedding = PositionalEncoding(embed_size, dropout, src_vocab_size)
        self.position_embedding = nn.Embedding(src_vocab_size, embed_size)

        self.layers = nn.ModuleList(
            [
                TransformerBlock(embed_size, heads, dropout, forward_expansion) for _ in range(num_layers)
            ]
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        N, seq_length = x.shape
        # print(N, seq_length)

        positions = torch.arange(0, seq_length).expand(
            N, seq_length).to(self.device)
        # print('positions shape', positions.shape)
        out = self.dropout(self.position_embedding(positions))

        for layer in self.layers:
            out = layer(out, out, out, mask)

        return out

In [8]:
class Transformer(nn.Module):

    def __init__(
        self,
        src_vocab_size,
        src_pad_index,
        embed_size=256,
        num_layers=6,
        forward_expansion=4,
        heads=8,
        dropout=0,
        device="cuda",
        max_length=100
    ):
        super(Transformer, self).__init__()

        self.encoder = Encoder(
            src_vocab_size, embed_size, num_layers, heads,
            device, forward_expansion, dropout, max_length)

        self.src_pad_index = src_pad_index
        self.output = nn.Linear(src_vocab_size * embed_size, 1)
        self.device = device
        self.embed_size = embed_size
        self.src_vocab_size = src_vocab_size

    def make_src_mask(self, src):
        src_mask = (src != self.src_pad_index).unsqueeze(1).unsqueeze(2)

        # (N, 1, 1, src_len)
        return src_mask.to(self.device)

    def forward(self, src):
        src_mask = self.make_src_mask(src)
        out = self.encoder(src, src_mask)

        out = out.reshape(-1, self.src_vocab_size * self.embed_size)

        return self.output(out)

In [9]:
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader, random_split


class AudioFeatureDataset(Dataset):
    def __init__(self, annotations_file, mode='train'):
        self.data = pd.read_csv(annotations_file)
        self.data = self.data.drop(['Unnamed: 0'], axis=1)

        # Splitting the dataset into train and validation sets
        total_samples = len(self.data)
        train_size = int(0.8 * total_samples)
        valid_size = total_samples - train_size

        if mode == 'train':
            self.data = self.data.iloc[:train_size]
        else:
            self.data = self.data.iloc[train_size:]

        self.features = torch.Tensor(self.data.drop(['class'], axis=1).values)
        self.labels = torch.Tensor(self.data['class'].values)
        self.mode = mode

    def __len__(self):
        # print(self.features.shape)
        return len(self.features)

    def __getitem__(self, idx):
        return self.features[idx], self.labels[idx]


# Example usage
dataset = AudioFeatureDataset('./working_dataset.csv', mode='train')
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

# for batch in dataloader:
#     x, y = batch
#     # Your training code here
# x_train = dataset[0][0]
# x_target = dataset[0][1]

# print(len(dataset))
# print(len(dataset[0][0]))

dataset_val = AudioFeatureDataset('./working_dataset.csv', mode="val")
dataloader_val = DataLoader(dataset, batch_size=1, shuffle=True)
# print(len(dataset_val))

In [10]:
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# print(device)

In [11]:
num_layers = 1
src_vocab_size = 297  # TIME-STEPS
src_pad_index = 0
embed_size = 296  # D-Model
num_heads = 1
dropout = 0.1
output_size = 1
forward_expansion = 4

model = Transformer(
    src_vocab_size,
    src_pad_index,
    embed_size=embed_size,
    dropout=dropout,
    heads=num_heads,
    num_layers=num_layers,
    forward_expansion=forward_expansion
).to(device)

In [12]:
X = torch.Tensor([[0, 124.46949, -21.54154, 0, 16.483715, 36.6015, 0, -25.834284, 0, 10.178377, 0, -7.345142, 10.264649, 16.472473, 7.895236, 5.5313177, 6.2877584, -4.182223, -14.012452, -7.7094264, 0.4547759, -5.8861904, -13.388045, -6.9823565, 1.1916866, -1.7707733, -5.276302, -0.35934502, 1.7799568, -4.779073, -6.8897967, 0.79337484, 4.2686896, -2.5059521, -6.496715, -1.0943556, 3.1531215, 0.25546533, -2.272549, -0.73324627, 0.0020391261, 0.012858815, 0.062261246, 0.030906683, 0.033696823, 0.06754137, 0.054984007, 0.36406443, 0.18963988, 0.10330482, 0.12989527, 0.42003468, 0.48301712, 0.11749379, 0.06369407, 0.12233541, 0.2568086, 0.5382474, 1.106481, 1.4184383, 0.71715385, 0.35170192, 0.09534879, 0.0366585, 0.102917545, 0.087631606, 0.03625296, 0.06454905, 0.095770694, 0.12915519, 0.25286278, 0.2482152, 0.20051777, 0.39578456, 0.28170255, 0.3759659, 0.4401238, 0.90044194, 0.8636782, 1.1747689, 0.7714911, 0.818637, 0.37247756, 0.3380456, 0.7661559, 2.4523528, 2.7824976, 8.302962, 16.734638, 8.80889, 12.927655, 13.964713, 7.212821, 7.622612, 17.17792, 10.23075, 8.3349285, 5.6397433, 3.6112895, 0.6386233, 0.24162434, 0.34277353, 0.36261582, 0.3141538, 0.38225174, 0.47966504, 0.03040914, 0.0012386683, 9.3100425e-05, 7.8610115e-05, 7.8027195e-05, 5.490173e-06, 1.6225702e-05, 5.6767494e-06, 4.4215603e-06, 9.099585e-06, 1.2576394e-06, 2.412565e-06, 7.926269e-07, 2.0946077e-06, 1.3263315e-06, 6.3015705e-07, 4.1206027e-07, 7.42592e-07, 5.666581e-07, 2.967201e-07, 1.7622109e-07, 4.3961313e-07, 1.7293216e-07, 1.4194566e-07, 1.8899503e-07, 1.5307475e-07, 9.427612e-08, 1.06292724e-07, 1.01943236e-07, 6.155833e-08, 7.990511e-08, 6.0735445e-08, 4.232562e-08, 6.248923e-08, 3.7052516e-08, 4.0408686e-08, 3.633701e-08, 2.8342097e-08, 3.3516038e-08, 2.3205962e-08, 2.8330943e-08, 2.0632546e-08, 2.3197451e-08, 1.7732157e-08, 1.965794e-08, 1.6046194e-08, 1.7022142e-08, 1.4855194e-08, 1.4529622e-08, 1.42051055e-08, 1.3379163e-08, 1.3088586e-08, 1.2451093e-08, 1.209253e-08, 1.2065797e-08, 1.1807311e-08, 1.1646068e-08, 1.16453025e-08, 1.1936379e-08, 1.2163221e-08, 1.1026455e-08, 7.149512e-09, 0.36461177, 0.34337756, 0.27882302, 0.48984087, 0.58728004, 0.33111906, 0.35057455, 0.40294, 0.37982342, 0.38546938, 0.35827973, 0.3449619, 4098.515885613107,
                 4394.679597751644, 2173.6066930992697, 660.1920797101604, 666.3801569895414, 679.7044596914801, 646.1064276646537, 678.6052827762849, 636.4525179151398, 588.8034579300204, 588.6927126149499, 605.9566149564441, 554.6815422905305, 643.55633138673, 635.2709118421336, 593.1425591116601, 689.9515513706316, 753.8378337006548, 713.6116050230352, 665.575526379199, 639.8702521661144, 566.4381100383742, 602.769675543065, 636.9758223614031, 838.8391505866864, 1478.874805344335, 1459.149044006308, 1387.6915345139455, 1353.654060604915, 1380.012526068448, 1383.4946564371012, 1341.579689123074, 1427.5678622884227, 1397.7808753654772, 1344.037783237732, 1257.285128280657, 1297.332690017431, 1316.4388778935704, 1307.883347941082, 1283.9768778672214, 1273.1843530634096, 1278.0824064357296, 1321.431115766408, 1354.4814820287468, 1333.5878041563171, 1357.2018216795293, 1347.5298328353635, 1317.862385602503, 1316.7824410928667, 1363.7911817467896, 1421.45053482153, 1473.7295655749122, 1513.3380612007554, 1527.07320879706, 1520.3705478366887, 1481.8238543145394, 1500.151441619609, 1508.9669344602714, 1481.524474097877, 1445.1450034130887, 1417.4826755078072, 1453.824676109189, 1476.554974221759, 1457.7835067613787, 1442.5007903703054, 1446.138078746579, 1480.3479277869212, 1474.320222966424, 1423.3878323139418, 1360.8577255891676, 1337.3912372750488, 1400.7434254248976, 1406.7530196210892, 1371.0172331395806, 1360.0256467642882, 1360.1123709007084, 1359.7692618475817, 1349.6076963128105, 1345.1738370383607, 1343.4491732619308, 1328.2974945469014, 1314.617505626696, 1321.0909928662404, 1335.4841595848136, 1343.4321393541222, 1344.775088467436, 1333.9444077975136, 1305.7116252799508, 1309.2780904215947, 1308.5759110007473, 1306.0999714339998, 1355.6284484769956, 1452.5165920413374, 1475.6366023948713, 1447.970350106245, 1406.542142884293, 1377.6039021203871, 1278.4930177581896, 1189.0178172699625, 1205.2225452548764, 1179.4867044106006, 881.8538957920834, 1197.5584167943596, 1324.3785722370592, 1361.4066076471468, 1263.928152316742, 1005.8203163981548, 759.0834395666938, 679.2875862507892, 708.2743155303299, 712.1759249973968, 604.9816438504809, 659.412993337054, 789.7184528406086, 815.1398626985261, 915.8414552062778, 773.3202934324213],])
# [-397.86472,124.46949,0,-48.849564,16.483715,36.6015,-6.4991527,-25.834284,-1.2522386,10.178377,-5.218914,-7.345142,10.264649,16.472473,7.895236,5.5313177,6.2877584,-4.182223,-14.012452,-7.7094264,0.4547759,-5.8861904,-13.388045,-6.9823565,1.1916866,-1.7707733,-5.276302,-0.35934502,1.7799568,-4.779073,-6.8897967,0.79337484,4.2686896,-2.5059521,-6.496715,-1.0943556,3.1531215,0.25546533,-2.272549,-0.73324627,0.0020391261,0.012858815,0.062261246,0.030906683,0.033696823,0.06754137,0.054984007,0.36406443,0.18963988,0.10330482,0.12989527,0.42003468,0.48301712,0.11749379,0.06369407,0.12233541,0.2568086,0.5382474,1.106481,1.4184383,0.71715385,0.35170192,0.09534879,0.0366585,0.102917545,0.087631606,0.03625296,0.06454905,0.095770694,0.12915519,0.25286278,0.2482152,0.20051777,0.39578456,0.28170255,0.3759659,0.4401238,0.90044194,0.8636782,1.1747689,0.7714911,0.818637,0.37247756,0.3380456,0.7661559,2.4523528,2.7824976,8.302962,16.734638,8.80889,12.927655,13.964713,7.212821,7.622612,17.17792,10.23075,8.3349285,5.6397433,3.6112895,0.6386233,0.24162434,0.34277353,0.36261582,0.3141538,0.38225174,0.47966504,0.03040914,0.0012386683,9.3100425e-05,7.8610115e-05,7.8027195e-05,5.490173e-06,1.6225702e-05,5.6767494e-06,4.4215603e-06,9.099585e-06,1.2576394e-06,2.412565e-06,7.926269e-07,2.0946077e-06,1.3263315e-06,6.3015705e-07,4.1206027e-07,7.42592e-07,5.666581e-07,2.967201e-07,1.7622109e-07,4.3961313e-07,1.7293216e-07,1.4194566e-07,1.8899503e-07,1.5307475e-07,9.427612e-08,1.06292724e-07,1.01943236e-07,6.155833e-08,7.990511e-08,6.0735445e-08,4.232562e-08,6.248923e-08,3.7052516e-08,4.0408686e-08,3.633701e-08,2.8342097e-08,3.3516038e-08,2.3205962e-08,2.8330943e-08,2.0632546e-08,2.3197451e-08,1.7732157e-08,1.965794e-08,1.6046194e-08,1.7022142e-08,1.4855194e-08,1.4529622e-08,1.42051055e-08,1.3379163e-08,1.3088586e-08,1.2451093e-08,1.209253e-08,1.2065797e-08,1.1807311e-08,1.1646068e-08,1.16453025e-08,1.1936379e-08,1.2163221e-08,1.1026455e-08,7.149512e-09,0.36461177,0.34337756,0.27882302,0.48984087,0.58728004,0.33111906,0.35057455,0.40294,0.37982342,0.38546938,0.35827973,0.3449619,4098.515885613107,4394.679597751644,2173.6066930992697,660.1920797101604,666.3801569895414,679.7044596914801,646.1064276646537,678.6052827762849,636.4525179151398,588.8034579300204,588.6927126149499,605.9566149564441,554.6815422905305,643.55633138673,635.2709118421336,593.1425591116601,689.9515513706316,753.8378337006548,713.6116050230352,665.575526379199,639.8702521661144,566.4381100383742,602.769675543065,636.9758223614031,838.8391505866864,1478.874805344335,1459.149044006308,1387.6915345139455,1353.654060604915,1380.012526068448,1383.4946564371012,1341.579689123074,1427.5678622884227,1397.7808753654772,1344.037783237732,1257.285128280657,1297.332690017431,1316.4388778935704,1307.883347941082,1283.9768778672214,1273.1843530634096,1278.0824064357296,1321.431115766408,1354.4814820287468,1333.5878041563171,1357.2018216795293,1347.5298328353635,1317.862385602503,1316.7824410928667,1363.7911817467896,1421.45053482153,1473.7295655749122,1513.3380612007554,1527.07320879706,1520.3705478366887,1481.8238543145394,1500.151441619609,1508.9669344602714,1481.524474097877,1445.1450034130887,1417.4826755078072,1453.824676109189,1476.554974221759,1457.7835067613787,1442.5007903703054,1446.138078746579,1480.3479277869212,1474.320222966424,1423.3878323139418,1360.8577255891676,1337.3912372750488,1400.7434254248976,1406.7530196210892,1371.0172331395806,1360.0256467642882,1360.1123709007084,1359.7692618475817,1349.6076963128105,1345.1738370383607,1343.4491732619308,1328.2974945469014,1314.617505626696,1321.0909928662404,1335.4841595848136,1343.4321393541222,1344.775088467436,1333.9444077975136,1305.7116252799508,1309.2780904215947,1308.5759110007473,1306.0999714339998,1355.6284484769956,1452.5165920413374,1475.6366023948713,1447.970350106245,1406.542142884293,1377.6039021203871,1278.4930177581896,1189.0178172699625,1205.2225452548764,1179.4867044106006,881.8538957920834,1197.5584167943596,1324.3785722370592,1361.4066076471468,1263.928152316742,1005.8203163981548,759.0834395666938,679.2875862507892,708.2743155303299,712.1759249973968,604.9816438504809,659.412993337054,789.7184528406086,815.1398626985261,915.8414552062778,773.3202934324213],
# [-397.86472,0,-21.54154,-48.849564,16.483715,36.6015,-6.4991527,-25.834284,-1.2522386,10.178377,-5.218914,-7.345142,10.264649,16.472473,7.895236,5.5313177,6.2877584,-4.182223,-14.012452,-7.7094264,0.4547759,-5.8861904,-13.388045,-6.9823565,1.1916866,-1.7707733,-5.276302,-0.35934502,1.7799568,-4.779073,-6.8897967,0.79337484,4.2686896,-2.5059521,-6.496715,-1.0943556,3.1531215,0.25546533,-2.272549,-0.73324627,0.0020391261,0.012858815,0.062261246,0.030906683,0.033696823,0.06754137,0.054984007,0.36406443,0.18963988,0.10330482,0.12989527,0.42003468,0.48301712,0.11749379,0.06369407,0.12233541,0.2568086,0.5382474,1.106481,1.4184383,0.71715385,0.35170192,0.09534879,0.0366585,0.102917545,0.087631606,0.03625296,0.06454905,0.095770694,0.12915519,0.25286278,0.2482152,0.20051777,0.39578456,0.28170255,0.3759659,0.4401238,0.90044194,0.8636782,1.1747689,0.7714911,0.818637,0.37247756,0.3380456,0.7661559,2.4523528,2.7824976,8.302962,16.734638,8.80889,12.927655,13.964713,7.212821,7.622612,17.17792,10.23075,8.3349285,5.6397433,3.6112895,0.6386233,0.24162434,0.34277353,0.36261582,0.3141538,0.38225174,0.47966504,0.03040914,0.0012386683,9.3100425e-05,7.8610115e-05,7.8027195e-05,5.490173e-06,1.6225702e-05,5.6767494e-06,4.4215603e-06,9.099585e-06,1.2576394e-06,2.412565e-06,7.926269e-07,2.0946077e-06,1.3263315e-06,6.3015705e-07,4.1206027e-07,7.42592e-07,5.666581e-07,2.967201e-07,1.7622109e-07,4.3961313e-07,1.7293216e-07,1.4194566e-07,1.8899503e-07,1.5307475e-07,9.427612e-08,1.06292724e-07,1.01943236e-07,6.155833e-08,7.990511e-08,6.0735445e-08,4.232562e-08,6.248923e-08,3.7052516e-08,4.0408686e-08,3.633701e-08,2.8342097e-08,3.3516038e-08,2.3205962e-08,2.8330943e-08,2.0632546e-08,2.3197451e-08,1.7732157e-08,1.965794e-08,1.6046194e-08,1.7022142e-08,1.4855194e-08,1.4529622e-08,1.42051055e-08,1.3379163e-08,1.3088586e-08,1.2451093e-08,1.209253e-08,1.2065797e-08,1.1807311e-08,1.1646068e-08,1.16453025e-08,1.1936379e-08,1.2163221e-08,1.1026455e-08,7.149512e-09,0.36461177,0.34337756,0.27882302,0.48984087,0.58728004,0.33111906,0.35057455,0.40294,0.37982342,0.38546938,0.35827973,0.3449619,4098.515885613107,4394.679597751644,2173.6066930992697,660.1920797101604,666.3801569895414,679.7044596914801,646.1064276646537,678.6052827762849,636.4525179151398,588.8034579300204,588.6927126149499,605.9566149564441,554.6815422905305,643.55633138673,635.2709118421336,593.1425591116601,689.9515513706316,753.8378337006548,713.6116050230352,665.575526379199,639.8702521661144,566.4381100383742,602.769675543065,636.9758223614031,838.8391505866864,1478.874805344335,1459.149044006308,1387.6915345139455,1353.654060604915,1380.012526068448,1383.4946564371012,1341.579689123074,1427.5678622884227,1397.7808753654772,1344.037783237732,1257.285128280657,1297.332690017431,1316.4388778935704,1307.883347941082,1283.9768778672214,1273.1843530634096,1278.0824064357296,1321.431115766408,1354.4814820287468,1333.5878041563171,1357.2018216795293,1347.5298328353635,1317.862385602503,1316.7824410928667,1363.7911817467896,1421.45053482153,1473.7295655749122,1513.3380612007554,1527.07320879706,1520.3705478366887,1481.8238543145394,1500.151441619609,1508.9669344602714,1481.524474097877,1445.1450034130887,1417.4826755078072,1453.824676109189,1476.554974221759,1457.7835067613787,1442.5007903703054,1446.138078746579,1480.3479277869212,1474.320222966424,1423.3878323139418,1360.8577255891676,1337.3912372750488,1400.7434254248976,1406.7530196210892,1371.0172331395806,1360.0256467642882,1360.1123709007084,1359.7692618475817,1349.6076963128105,1345.1738370383607,1343.4491732619308,1328.2974945469014,1314.617505626696,1321.0909928662404,1335.4841595848136,1343.4321393541222,1344.775088467436,1333.9444077975136,1305.7116252799508,1309.2780904215947,1308.5759110007473,1306.0999714339998,1355.6284484769956,1452.5165920413374,1475.6366023948713,1447.970350106245,1406.542142884293,1377.6039021203871,1278.4930177581896,1189.0178172699625,1205.2225452548764,1179.4867044106006,881.8538957920834,1197.5584167943596,1324.3785722370592,1361.4066076471468,1263.928152316742,1005.8203163981548,759.0834395666938,679.2875862507892,708.2743155303299,712.1759249973968,604.9816438504809,659.412993337054,789.7184528406086,815.1398626985261,915.8414552062778,773.3202934324213],
# [0,124.46949,-21.54154,-48.849564,16.483715,36.6015,-6.4991527,-25.834284,-1.2522386,10.178377,-5.218914,-7.345142,10.264649,16.472473,7.895236,5.5313177,6.2877584,-4.182223,-14.012452,-7.7094264,0.4547759,-5.8861904,-13.388045,-6.9823565,1.1916866,-1.7707733,-5.276302,-0.35934502,1.7799568,-4.779073,-6.8897967,0.79337484,4.2686896,-2.5059521,-6.496715,-1.0943556,3.1531215,0.25546533,-2.272549,-0.73324627,0.0020391261,0.012858815,0.062261246,0.030906683,0.033696823,0.06754137,0.054984007,0.36406443,0.18963988,0.10330482,0.12989527,0.42003468,0.48301712,0.11749379,0.06369407,0.12233541,0.2568086,0.5382474,1.106481,1.4184383,0.71715385,0.35170192,0.09534879,0.0366585,0.102917545,0.087631606,0.03625296,0.06454905,0.095770694,0.12915519,0.25286278,0.2482152,0.20051777,0.39578456,0.28170255,0.3759659,0.4401238,0.90044194,0.8636782,1.1747689,0.7714911,0.818637,0.37247756,0.3380456,0.7661559,2.4523528,2.7824976,8.302962,16.734638,8.80889,12.927655,13.964713,7.212821,7.622612,17.17792,10.23075,8.3349285,5.6397433,3.6112895,0.6386233,0.24162434,0.34277353,0.36261582,0.3141538,0.38225174,0.47966504,0.03040914,0.0012386683,9.3100425e-05,7.8610115e-05,7.8027195e-05,5.490173e-06,1.6225702e-05,5.6767494e-06,4.4215603e-06,9.099585e-06,1.2576394e-06,2.412565e-06,7.926269e-07,2.0946077e-06,1.3263315e-06,6.3015705e-07,4.1206027e-07,7.42592e-07,5.666581e-07,2.967201e-07,1.7622109e-07,4.3961313e-07,1.7293216e-07,1.4194566e-07,1.8899503e-07,1.5307475e-07,9.427612e-08,1.06292724e-07,1.01943236e-07,6.155833e-08,7.990511e-08,6.0735445e-08,4.232562e-08,6.248923e-08,3.7052516e-08,4.0408686e-08,3.633701e-08,2.8342097e-08,3.3516038e-08,2.3205962e-08,2.8330943e-08,2.0632546e-08,2.3197451e-08,1.7732157e-08,1.965794e-08,1.6046194e-08,1.7022142e-08,1.4855194e-08,1.4529622e-08,1.42051055e-08,1.3379163e-08,1.3088586e-08,1.2451093e-08,1.209253e-08,1.2065797e-08,1.1807311e-08,1.1646068e-08,1.16453025e-08,1.1936379e-08,1.2163221e-08,1.1026455e-08,7.149512e-09,0.36461177,0.34337756,0.27882302,0.48984087,0.58728004,0.33111906,0.35057455,0.40294,0.37982342,0.38546938,0.35827973,0.3449619,4098.515885613107,4394.679597751644,2173.6066930992697,660.1920797101604,666.3801569895414,679.7044596914801,646.1064276646537,678.6052827762849,636.4525179151398,588.8034579300204,588.6927126149499,605.9566149564441,554.6815422905305,643.55633138673,635.2709118421336,593.1425591116601,689.9515513706316,753.8378337006548,713.6116050230352,665.575526379199,639.8702521661144,566.4381100383742,602.769675543065,636.9758223614031,838.8391505866864,1478.874805344335,1459.149044006308,1387.6915345139455,1353.654060604915,1380.012526068448,1383.4946564371012,1341.579689123074,1427.5678622884227,1397.7808753654772,1344.037783237732,1257.285128280657,1297.332690017431,1316.4388778935704,1307.883347941082,1283.9768778672214,1273.1843530634096,1278.0824064357296,1321.431115766408,1354.4814820287468,1333.5878041563171,1357.2018216795293,1347.5298328353635,1317.862385602503,1316.7824410928667,1363.7911817467896,1421.45053482153,1473.7295655749122,1513.3380612007554,1527.07320879706,1520.3705478366887,1481.8238543145394,1500.151441619609,1508.9669344602714,1481.524474097877,1445.1450034130887,1417.4826755078072,1453.824676109189,1476.554974221759,1457.7835067613787,1442.5007903703054,1446.138078746579,1480.3479277869212,1474.320222966424,1423.3878323139418,1360.8577255891676,1337.3912372750488,1400.7434254248976,1406.7530196210892,1371.0172331395806,1360.0256467642882,1360.1123709007084,1359.7692618475817,1349.6076963128105,1345.1738370383607,1343.4491732619308,1328.2974945469014,1314.617505626696,1321.0909928662404,1335.4841595848136,1343.4321393541222,1344.775088467436,1333.9444077975136,1305.7116252799508,1309.2780904215947,1308.5759110007473,1306.0999714339998,1355.6284484769956,1452.5165920413374,1475.6366023948713,1447.970350106245,1406.542142884293,1377.6039021203871,1278.4930177581896,1189.0178172699625,1205.2225452548764,1179.4867044106006,881.8538957920834,1197.5584167943596,1324.3785722370592,1361.4066076471468,1263.928152316742,1005.8203163981548,759.0834395666938,679.2875862507892,708.2743155303299,712.1759249973968,604.9816438504809,659.412993337054,789.7184528406086,815.1398626985261,915.8414552062778,773.3202934324213]])

print(X.shape)
# y = torch.Tensor([3.21]).to(device)

pred = model(X).squeeze()
print(pred)
# loss_fn = nn.MSELoss()
# print(loss_fn(pred, y))

torch.Size([1, 297])
tensor(-0.8205, device='cuda:0', grad_fn=<SqueezeBackward0>)


In [13]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.000003)

loss_fn = nn.MSELoss()

running_loss = 0.
last_loss = 0.

model.train()

epochs = 1000

for epoch in range(epochs):
    print(f"Epoch {epoch+1}\n-------------------------------")
    for batch, (X, y) in enumerate(dataloader):
        # Compute prediction and loss
        # print('X', X.shape)
        X = X.to(device)
        y = y.to(device)

        optimizer.zero_grad()
        pred = model(X)
        # print("Prediction")
        pred = pred.squeeze()
        loss = loss_fn(pred, y)

        # losses.append(loss.to('cpu').detach().numpy())
        # iterations += 1

        # Backpropagation
        loss.backward()
        # print("Loss gradient", loss.grad)
        optimizer.step()

        running_loss += loss.item()

        if batch % 10 == 0:
            last_loss = running_loss / 100  # loss per batch
            print('batch {} loss: {}'.format(batch + 1, last_loss))
            # tb_x = epoch * len(dataloader) + batch + 1
            # tb_writer.add_scalar('Loss/train', last_loss, tb_x)
            running_loss = 0.

    # model.eval()
    # size = len(dataloader_val.dataset)
    # num_batches = len(dataloader_val)
    # test_loss, correct = 0, 0

    # with torch.no_grad():
    #   for batch, (X, y) in enumerate(dataloader_val):
    #       X = X.to(device)
    #       y = y.to(device)
    #       pred = model(X)
    #       test_loss += loss_fn(pred, y).item()
    #       correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    # test_loss /= num_batches
    # correct /= size
    # print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
    break

print('Done')

Epoch 1
-------------------------------
batch 1 loss: 0.3137862396240234


  return F.mse_loss(input, target, reduction=self.reduction)


batch 11 loss: 0.8799844142200891
batch 21 loss: 0.3316475596372038
batch 31 loss: 0.18205110877053812
batch 41 loss: 0.15711148275062442
batch 51 loss: 0.16122780553996563
batch 61 loss: 0.1625233540865662
batch 71 loss: 0.15334290117025376
batch 81 loss: 0.1982310657016933
batch 91 loss: 0.25095088094472884
batch 101 loss: 0.11188910387456417
batch 111 loss: 0.35187231820076703
batch 121 loss: 0.15820662846788763
batch 131 loss: 0.23669531404972077
batch 141 loss: 0.2049787424504757
batch 151 loss: 0.14772035262431019
batch 161 loss: 0.11071514077484608
batch 171 loss: 0.24033718943595886
batch 181 loss: 0.2651617312431335
batch 191 loss: 0.17597591083496808
batch 201 loss: 0.09521109985653312
batch 211 loss: 0.15024343252182007
batch 221 loss: 0.19694268659513908
batch 231 loss: 0.1372122206632048
batch 241 loss: 0.20314003211795353
batch 251 loss: 0.20448532313108445
batch 261 loss: 0.25657039068639276
batch 271 loss: 0.2921939765755087
batch 281 loss: 0.14327034532674587
batch 291

In [14]:
# run
import matplotlib.pyplot as plt

plt.plot(losses)
plt.xlabel("Iterations")
plt.ylabel("Loss")
# plt.savefig("training.jpg")

NameError: name 'losses' is not defined