# This is a sample Jupyter Notebook

Below is an example of a code cell.
Put your cursor into the cell and press Shift+Enter to execute it and select the next one, or click the 'Run Cell' button.

Press Double Shift to search everywhere for classes, files, tool windows, actions, and settings.

To learn more about Jupyter Notebooks in PyCharm, see [help](https://www.jetbrains.com/help/pycharm/ipython-notebook-support.html).
For an overview of PyCharm, go to Help -> Learn IDE features or refer to [our documentation](https://www.jetbrains.com/help/pycharm/getting-started.html).

In [None]:
import torch
import torch.nn as nn

In [None]:
# Report whether this PyTorch build can use Apple's Metal (MPS) backend here.
mps_available = torch.backends.mps.is_available()
print(mps_available)

In [None]:
# Pick the compute device: Apple GPU (MPS) when available, otherwise CPU.
use_mps = torch.backends.mps.is_available()
device = torch.device("mps" if use_mps else "cpu")
print("Using MPS device" if use_mps else "MPS device not found, using CPU")

# Example: allocate a tensor directly on the selected device.
x = torch.ones(5, device=device)
class TwoBranch(nn.Module):
    """Encode two sequences with LSTMs and mix their final hidden states.

    Each input goes through an LSTM (``batch_first=False``, so inputs are
    ``(seq_len, batch, input_size)``). The last-layer final hidden states of
    both branches are concatenated and passed through one ``nn.Linear``.

    By default both branches are the *same* ``nn.LSTM`` instance, so left and
    right inputs are encoded with shared weights. This matches the original
    construction, where one module was assigned to both attributes.
    NOTE(review): confirm the sharing is intentional. Pass
    ``share_weights=False`` to give each branch its own LSTM.

    Args:
        input_size: Features per time step of each input sequence.
        hidden_size: Hidden size of each branch LSTM.
        out_features: Output width of the final linear layer.
        share_weights: If True, ``self.left is self.right``.
    """

    def __init__(self, input_size=3, hidden_size=5, out_features=10, share_weights=True):
        super().__init__()
        left = nn.LSTM(input_size=input_size, hidden_size=hidden_size, batch_first=False)
        if share_weights:
            right = left
        else:
            right = nn.LSTM(input_size=input_size, hidden_size=hidden_size, batch_first=False)
        self.left = left
        self.right = right
        # Concatenating two hidden states doubles the width, so derive the input
        # size from hidden_size instead of hard-coding it.
        self.combine = nn.Linear(2 * hidden_size, out_features)

    def forward(self, x_left, x_right):
        """Return ``combine(cat(h_left, h_right))`` with shape ``(batch, out_features)``."""
        _, (l_h, _) = self.left(x_left)
        _, (r_h, _) = self.right(x_right)

        # h_n is (num_layers, batch, hidden_size); [-1] selects the last layer.
        l = l_h[-1]  # (batch, hidden_size)
        r = r_h[-1]

        cat = torch.cat([l, r], dim=-1)  # (batch, 2 * hidden_size)
        return self.combine(cat)         # (batch, out_features)
class BigModel(nn.Module):
    """Two-branch LSTM encoder followed by a small MLP head.

    ``forward(x_left, x_right)`` runs both sequences through a
    :class:`TwoBranch` and maps its 10-wide output to a single feature through
    Linear layers of widths 10 -> 8 -> 5 -> 1, with a ReLU between consecutive
    Linear layers.
    """

    def __init__(self):
        super().__init__()
        self.branch = TwoBranch()
        # Build Linear layers with a ReLU between each pair; no activation after the last.
        widths = (10, 8, 5, 1)
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            if layers:
                layers.append(nn.ReLU())
            layers.append(nn.Linear(fan_in, fan_out))
        self.rest = nn.Sequential(*layers)

    def forward(self, x_left, x_right):
        return self.rest(self.branch(x_left, x_right))
# Build the full model and move it to the selected device. nn.Module.to returns
# the module itself, so model2 is bound to the moved model.
model2 = BigModel().to(device)
model2  # cell output: display the model



In [None]:
seq_len, batch = 64, 32

# Two synthetic input sequences of shape (seq_len, batch, 3). This is the
# time-major layout the batch_first=False LSTMs consume. Left is drawn first.
x_left, x_right = (torch.randn(seq_len, batch, 3, device=device) for _ in range(2))

# Random regression targets of shape (batch, 1); substitute real labels for the
# actual problem.
y = torch.randn(batch, 1, device=device)

In [None]:
import pandas as pd
import glob


def load_person_drawings(path):
    """Read one person's JSON file and return their drawings as point triples.

    Each row of the DataFrame that ``pd.read_json`` builds from *path* is
    treated as one drawing, and only ``row[0][0]["points"]`` is read from it.
    NOTE(review): this assumes the first column holds a list whose first
    element is a dict with a ``"points"`` list. Any further elements of that
    list are ignored; confirm nothing is being dropped.

    Returns:
        One list per row, each a list of ``(x, y, time)`` tuples taken from
        the ``"x"``, ``"y"`` and ``"time"`` keys of every point.
    """
    df = pd.read_json(path)
    return [
        [(point["x"], point["y"], point["time"]) for point in drawing[0][0]["points"]]
        for drawing in df.values
    ]


# glob returns files in arbitrary, filesystem-dependent order. Sorting makes the
# person order reproducible for anything that indexes `people` by position.
files = sorted(glob.glob("data/*.json"))
people = [load_person_drawings(file) for file in files]

data = people


In [None]:
import json
from pathlib import Path

out_path = Path("output/data.json")

# write_text does not create missing parent directories, so make sure output/ exists.
out_path.parent.mkdir(parents=True, exist_ok=True)
out_path.write_text(json.dumps(data, indent=2), encoding="utf-8")

In [None]:
out_path = Path("output/data.json")
# Use a context manager so the file handle is closed deterministically rather
# than whenever the garbage collector gets to it.
with out_path.open(encoding="utf-8") as fh:
    loaded = json.load(fh)

In [None]:


# Nesting: people -> drawings -> points -> [x, y, time]. The JSON round-trip
# turns each (x, y, time) tuple into a list, so the innermost level is a list.
training: list[list[list[list[float]]]] = loaded[0:2]  # first two people
# The data carries no explicit person IDs, so a person's position serves as the ID.
# Unpacking a person's drawing list directly into (personID, drawings) would
# raise ValueError unless that person had exactly two drawings, and even then
# it would bind two drawings rather than an ID.
transformed = list(enumerate(training))

In [None]:
# Full-batch regression on the synthetic inputs: MSE loss, Adam optimizer.
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(params=model2.parameters(), lr=1e-3)

NUM_EPOCHS = 1_000
LOG_EVERY = 100

for epoch in range(NUM_EPOCHS):
    # Clear gradients from the previous step before the forward pass.
    optimizer.zero_grad()
    output = model2(x_left, x_right)
    loss = criterion(output, y)
    # Backpropagate and take one optimizer step.
    loss.backward()
    optimizer.step()
    if not epoch % LOG_EVERY:
        print(epoch, loss.item())

model2  # cell output: display the trained model