In [1]:
import torch.nn.functional as F
import torch
import pandas as pd
import os

In [2]:
num_classes = 9
input_chunk_length = 120

In [3]:
def dscovr_df(filename):
    df = pd.read_csv(os.path.join(os.path.curdir, "data", filename))
    df = df.drop(columns=["Unnamed: 0.1", "Unnamed: 0"])
    df.index = df["0"]
    return df.drop(columns="0")
# print(dscovr_df("data_2016.csv")["k_index_target"].unique().size)

In [4]:
from dscovr_dataset import create_dscovr_dataset
import torch.utils.data as data
import torch.nn as nn
import torch.optim as optim
import pandas as pd

train_df = dscovr_df("data_2017.csv")
X_train, y_train = create_dscovr_dataset(train_df.to_numpy(), lookback=input_chunk_length, num_classes=num_classes)
data_train = data.TensorDataset(X_train, y_train)

test_df = dscovr_df("data_2020.csv")
X_test, y_test = create_dscovr_dataset(test_df.to_numpy(), lookback=input_chunk_length, num_classes=num_classes)
data_test = data.TensorDataset(X_test, y_test)

batch_size = 32
train_loader = data.DataLoader(data_train, shuffle=True, batch_size=batch_size)
test_loader = data.DataLoader(data_test, shuffle=True, batch_size=batch_size * 8)

100%|██████████| 10000/10000 [00:00<00:00, 31635.36it/s]
100%|██████████| 4410/4410 [00:00<00:00, 68978.32it/s]


In [5]:
from transformer_classifier import TransformerClassifier
model = TransformerClassifier(
    num_classes=num_classes,
    input_chunk_length = input_chunk_length,
    d_model = 64,
    input_dim=54,
    nhead = 8,
    dim_feedforward = 256,
    num_layers = 4,
    dropout = 0.5,
    activation = "relu",
    classifier_dropout = 0.5,
).to(device="cuda")
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=3e-5)

In [6]:
from tqdm import tqdm
from torchmetrics.functional import precision, recall, accuracy

n_epochs = 50
for epoch in range(n_epochs):
    model.train()
    prec, rec, acc = [], [], []
    run_loss = 0
    for X_batch, y_batch in tqdm(train_loader):
        X_batch, y_batch = X_batch.to(device="cuda"), y_batch.to(device="cuda", dtype=torch.float)
        y_pred = model(X_batch)

        optimizer.zero_grad()
        loss = loss_fn(y_pred, y_batch)

        y_pred = y_pred.argmax(dim=1)
        y_batch = y_batch.argmax(dim=1)
        curr_prec = precision(y_pred, y_batch, "multiclass", num_classes=num_classes, average="macro")
        curr_rec = recall(y_pred, y_batch, "multiclass", num_classes=num_classes, average="macro")
        curr_acc = accuracy(y_pred, y_batch, "multiclass", num_classes=num_classes, average="macro")
        prec.append(curr_prec.item()), rec.append(curr_rec.item()), acc.append(curr_acc.item())
        run_loss += loss.item()

        loss.backward()
        optimizer.step()
    print(f"For index {epoch}: precision: {sum(prec)/len(prec)} recall: {sum(rec)/len(rec)} acc: {sum(acc)/len(acc)}")
    print("--------------------------------")

    model.eval()
    if epoch % 5 == 4:
        prec, rec, acc = [], [], []
        with torch.no_grad():
            for i, (inputs, targets) in tqdm(enumerate(test_loader)):
                inputs, targets = inputs.to(device="cuda"), targets.to(device="cuda")
                y_pred = model(inputs)

                y_pred = y_pred.argmax(dim=1)
                targets = targets.argmax(dim=1)
                curr_prec = precision(y_pred, targets, "multiclass", num_classes=num_classes, average="macro")
                curr_rec = recall(y_pred, targets, "multiclass", num_classes=num_classes, average="macro")
                curr_acc = accuracy(y_pred, targets, "multiclass", num_classes=num_classes, average="macro")
                prec.append(curr_prec.item()), rec.append(curr_rec.item()), acc.append(curr_acc.item())

        # Print accuracy
        print(f"For test: precision: {sum(prec)/len(prec)} recall: {sum(rec)/len(rec)} acc: {sum(acc)/len(acc)}")
        print("--------------------------------")

100%|██████████| 302/302 [00:23<00:00, 13.10it/s]


For index 0: precision: 0.22721151977974846 recall: 0.15172735342620225 acc: 0.15172735342620225
--------------------------------


100%|██████████| 302/302 [00:22<00:00, 13.25it/s]


For index 1: precision: 0.2794771987747475 recall: 0.18005449706010077 acc: 0.18005449706010077
--------------------------------


100%|██████████| 302/302 [00:22<00:00, 13.28it/s]


For index 2: precision: 0.2884467894470455 recall: 0.18331830462576537 acc: 0.18331830462576537
--------------------------------


100%|██████████| 302/302 [00:22<00:00, 13.27it/s]


For index 3: precision: 0.34313317538787985 recall: 0.20668688483092168 acc: 0.20668688483092168
--------------------------------


100%|██████████| 302/302 [00:22<00:00, 13.27it/s]


For index 4: precision: 0.3663358415002065 recall: 0.20898592930953233 acc: 0.20898592930953233
--------------------------------


8it [00:01,  4.80it/s]


For test: precision: 0.20031467080116272 recall: 0.1483655944466591 acc: 0.1483655944466591
--------------------------------


100%|██████████| 302/302 [00:22<00:00, 13.29it/s]


For index 5: precision: 0.370423645225187 recall: 0.21691730201540405 acc: 0.21691730201540405
--------------------------------


100%|██████████| 302/302 [00:22<00:00, 13.27it/s]


For index 6: precision: 0.38160586680304137 recall: 0.21247198741068904 acc: 0.21247198741068904
--------------------------------


100%|██████████| 302/302 [00:22<00:00, 13.28it/s]


For index 7: precision: 0.41190002266539644 recall: 0.220831942664373 acc: 0.220831942664373
--------------------------------


100%|██████████| 302/302 [00:22<00:00, 13.28it/s]


For index 8: precision: 0.46203969547290674 recall: 0.23832390797848732 acc: 0.23832390797848732
--------------------------------


100%|██████████| 302/302 [00:22<00:00, 13.25it/s]


For index 9: precision: 0.46527536479842585 recall: 0.2324209889483373 acc: 0.2324209889483373
--------------------------------


8it [00:01,  4.82it/s]


For test: precision: 0.313056331127882 recall: 0.2563162576407194 acc: 0.2563162576407194
--------------------------------


100%|██████████| 302/302 [00:22<00:00, 13.26it/s]


For index 10: precision: 0.4839833766123317 recall: 0.24709317796179 acc: 0.24709317796179
--------------------------------


100%|██████████| 302/302 [00:22<00:00, 13.28it/s]


For index 11: precision: 0.4769554969096026 recall: 0.24158489136703756 acc: 0.24158489136703756
--------------------------------


100%|██████████| 302/302 [00:22<00:00, 13.27it/s]


For index 12: precision: 0.49434266120984854 recall: 0.2522787753961339 acc: 0.2522787753961339
--------------------------------


100%|██████████| 302/302 [00:22<00:00, 13.24it/s]


For index 13: precision: 0.4931163349195032 recall: 0.2538663395094556 acc: 0.2538663395094556
--------------------------------


100%|██████████| 302/302 [00:23<00:00, 13.11it/s]


For index 14: precision: 0.5087291247994694 recall: 0.2563142664977257 acc: 0.2563142664977257
--------------------------------


8it [00:01,  4.70it/s]


For test: precision: 0.35815083235502243 recall: 0.23888385482132435 acc: 0.23888385482132435
--------------------------------


100%|██████████| 302/302 [00:23<00:00, 12.92it/s]


For index 15: precision: 0.4972354364118829 recall: 0.25413571216708775 acc: 0.25413571216708775
--------------------------------


100%|██████████| 302/302 [00:22<00:00, 13.18it/s]


For index 16: precision: 0.502839968200551 recall: 0.2562616453413537 acc: 0.2562616453413537
--------------------------------


100%|██████████| 302/302 [00:22<00:00, 13.19it/s]


For index 17: precision: 0.5141904596362682 recall: 0.25996787520433895 acc: 0.25996787520433895
--------------------------------


100%|██████████| 302/302 [00:22<00:00, 13.18it/s]


For index 18: precision: 0.5085417251318496 recall: 0.2583041170130897 acc: 0.2583041170130897
--------------------------------


100%|██████████| 302/302 [00:22<00:00, 13.20it/s]


For index 19: precision: 0.5063528802171843 recall: 0.25451561450859567 acc: 0.25451561450859567
--------------------------------


8it [00:01,  4.80it/s]


For test: precision: 0.09064322896301746 recall: 0.16643052734434605 acc: 0.16643052734434605
--------------------------------


100%|██████████| 302/302 [00:22<00:00, 13.20it/s]


For index 20: precision: 0.5200119550210356 recall: 0.26138558093187036 acc: 0.26138558093187036
--------------------------------


100%|██████████| 302/302 [00:22<00:00, 13.19it/s]


For index 21: precision: 0.5299700120624328 recall: 0.26947284558948303 acc: 0.26947284558948303
--------------------------------


100%|██████████| 302/302 [00:23<00:00, 13.13it/s]


For index 22: precision: 0.5325616033168028 recall: 0.27270563945963683 acc: 0.27270563945963683
--------------------------------


100%|██████████| 302/302 [00:23<00:00, 13.07it/s]


For index 23: precision: 0.5209786904272654 recall: 0.2663251787966845 acc: 0.2663251787966845
--------------------------------


100%|██████████| 302/302 [00:23<00:00, 13.12it/s]


For index 24: precision: 0.537222542066053 recall: 0.2709663522657969 acc: 0.2709663522657969
--------------------------------


8it [00:01,  4.75it/s]


For test: precision: 0.0974681917577982 recall: 0.17337393201887608 acc: 0.17337393201887608
--------------------------------


100%|██████████| 302/302 [00:23<00:00, 13.07it/s]


For index 25: precision: 0.5271403898565185 recall: 0.2692126574007091 acc: 0.2692126574007091
--------------------------------


100%|██████████| 302/302 [00:23<00:00, 13.06it/s]


For index 26: precision: 0.5264718082566925 recall: 0.26829549576470396 acc: 0.26829549576470396
--------------------------------


100%|██████████| 302/302 [00:22<00:00, 13.16it/s]


For index 27: precision: 0.5401614477598904 recall: 0.2817792976336763 acc: 0.2817792976336763
--------------------------------


100%|██████████| 302/302 [00:23<00:00, 13.11it/s]


For index 28: precision: 0.5360113129217103 recall: 0.2715060945032843 acc: 0.2715060945032843
--------------------------------


100%|██████████| 302/302 [00:23<00:00, 13.06it/s]


For index 29: precision: 0.546879250383535 recall: 0.2764273668016424 acc: 0.2764273668016424
--------------------------------


8it [00:01,  4.74it/s]


For test: precision: 0.1007743226364255 recall: 0.17458654381334782 acc: 0.17458654381334782
--------------------------------


100%|██████████| 302/302 [00:22<00:00, 13.15it/s]


For index 30: precision: 0.536917456333211 recall: 0.26951891339279166 acc: 0.26951891339279166
--------------------------------


100%|██████████| 302/302 [00:22<00:00, 13.19it/s]


For index 31: precision: 0.5498716424632546 recall: 0.27158006389212136 acc: 0.27158006389212136
--------------------------------


100%|██████████| 302/302 [00:22<00:00, 13.20it/s]


For index 32: precision: 0.5397326143964237 recall: 0.2644402275713074 acc: 0.2644402275713074
--------------------------------


100%|██████████| 302/302 [00:22<00:00, 13.17it/s]


For index 33: precision: 0.5331396777799587 recall: 0.26834242534361136 acc: 0.26834242534361136
--------------------------------


100%|██████████| 302/302 [00:23<00:00, 13.08it/s]


For index 34: precision: 0.5694471952930981 recall: 0.2851633065633032 acc: 0.2851633065633032
--------------------------------


8it [00:01,  4.75it/s]


For test: precision: 0.11860364209860563 recall: 0.1760083418339491 acc: 0.1760083418339491
--------------------------------


100%|██████████| 302/302 [00:22<00:00, 13.16it/s]


For index 35: precision: 0.5382016161024965 recall: 0.26744665733433715 acc: 0.26744665733433715
--------------------------------


100%|██████████| 302/302 [00:23<00:00, 13.07it/s]


For index 36: precision: 0.5527850250337297 recall: 0.2779706336360499 acc: 0.2779706336360499
--------------------------------


100%|██████████| 302/302 [00:23<00:00, 13.07it/s]


For index 37: precision: 0.5495932455982594 recall: 0.2747056173123666 acc: 0.2747056173123666
--------------------------------


100%|██████████| 302/302 [00:23<00:00, 13.09it/s]


For index 38: precision: 0.538944985840889 recall: 0.2663845404587834 acc: 0.2663845404587834
--------------------------------


100%|██████████| 302/302 [00:23<00:00, 13.09it/s]


For index 39: precision: 0.5552817805121276 recall: 0.2813653577341149 acc: 0.2813653577341149
--------------------------------


8it [00:01,  4.71it/s]


For test: precision: 0.1212407685816288 recall: 0.16827768832445145 acc: 0.16827768832445145
--------------------------------


100%|██████████| 302/302 [00:22<00:00, 13.15it/s]


For index 40: precision: 0.5518396368978039 recall: 0.2750752024007159 acc: 0.2750752024007159
--------------------------------


100%|██████████| 302/302 [00:22<00:00, 13.19it/s]


For index 41: precision: 0.5495548184146944 recall: 0.27940943753285125 acc: 0.27940943753285125
--------------------------------


100%|██████████| 302/302 [00:22<00:00, 13.18it/s]


For index 42: precision: 0.5437548615482469 recall: 0.2718561691114839 acc: 0.2718561691114839
--------------------------------


100%|██████████| 302/302 [00:22<00:00, 13.16it/s]


For index 43: precision: 0.5555935346626288 recall: 0.28226286747696383 acc: 0.28226286747696383
--------------------------------


100%|██████████| 302/302 [00:23<00:00, 13.12it/s]


For index 44: precision: 0.5607600404529383 recall: 0.28245437389474043 acc: 0.28245437389474043
--------------------------------


8it [00:01,  4.80it/s]


For test: precision: 0.10219340492039919 recall: 0.13532843999564648 acc: 0.13532843999564648
--------------------------------


100%|██████████| 302/302 [00:23<00:00, 13.12it/s]


For index 45: precision: 0.5506511156922145 recall: 0.2775328324626613 acc: 0.2775328324626613
--------------------------------


100%|██████████| 302/302 [00:23<00:00, 13.11it/s]


For index 46: precision: 0.5512609451713152 recall: 0.274479006225897 acc: 0.274479006225897
--------------------------------


100%|██████████| 302/302 [00:23<00:00, 13.11it/s]


For index 47: precision: 0.5679829725563921 recall: 0.2829793150357853 acc: 0.2829793150357853
--------------------------------


100%|██████████| 302/302 [00:23<00:00, 13.11it/s]


For index 48: precision: 0.5634219154024755 recall: 0.2871630335140307 acc: 0.2871630335140307
--------------------------------


100%|██████████| 302/302 [00:23<00:00, 13.12it/s]


For index 49: precision: 0.5817879694976554 recall: 0.3032151830709533 acc: 0.3032151830709533
--------------------------------


8it [00:01,  4.78it/s]

For test: precision: 0.10868454165756702 recall: 0.1326053524389863 acc: 0.1326053524389863
--------------------------------





In [7]:
from torchmetrics.functional import precision
preds  = torch.tensor([[1, 0], [1, 0], [1, 0], [1, 0]])
target = torch.tensor([[1, 0], [1, 0], [1, 0], [1, 0]])
print(precision(preds, target, task="multiclass", average='micro', num_classes=2))
# precision(preds, target, average='micro')

tensor(1.)
