In [3]:
!pip install wandb
!pip install gcsfs
!pip install catboost

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting catboost
  Downloading catboost-1.1.1-cp38-none-manylinux1_x86_64.whl (76.6 MB)
[K     |████████████████████████████████| 76.6 MB 1.2 MB/s 
Installing collected packages: catboost
Successfully installed catboost-1.1.1


In [6]:
import torch
import gcsfs

import google.auth
from google.colab import auth

from torch import nn, optim
import torch.nn.functional as F


# Authenticate the Colab user first; google.auth.default() on the next line is
# then expected to find those user credentials (Colab behaviour, assumed).
auth.authenticate_user()
credentials, _ = google.auth.default()

# URL scheme prefix for bucket paths, plus a gcsfs handle bound to the
# credentials obtained above.
fs_prefix = "gs://"
fs = gcsfs.GCSFileSystem(token=credentials, project="thesis")
fs_prefix = "gs://"

In [7]:
class Net(nn.Module):
    """LeNet-style CNN: two conv + max-pool stages followed by three linear layers.

    fc1 expects the conv stack to produce a 16 x 5 x 5 feature map per sample.
    This is what 3-channel 32 x 32 inputs give: 32 -> 28 -> 14 -> 10 -> 5.

    Args:
        num_classes: Number of output logits (default 10).
    """

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return unnormalised class scores of shape (batch, num_classes).

        Args:
            x: Batch of shape (batch, 3, H, W), or a single unbatched
                (3, H, W) image. H = W = 32 matches fc1's input size.
        """
        if x.dim() == 3:
            # Unbatched (C, H, W) input: treat it as a batch of one.
            x = x.unsqueeze(0)
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten per sample and keep the batch dimension. A wrongly sized
        # input then fails loudly at fc1. view(-1, 400) would instead silently
        # regroup the features into a different number of rows.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)


# Build the network and show its layer summary (nn.Module's repr).
net = Net()
print(repr(net))


Net(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)


In [8]:
# SGD with momentum over every trainable parameter of `net`.
optimizer = optim.SGD(params=net.parameters(), lr=1e-3, momentum=0.9)


In [9]:
# Hard-coded checkpoint metadata; also reused by the GCS upload cell.
EPOCH = 5
PATH = "model.pt"
LOSS = 0.4

# Gather model + optimizer state into a single dict and write it locally.
checkpoint = {
    "epoch": EPOCH,
    "model_state_dict": net.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "loss": LOSS,
}
torch.save(checkpoint, PATH)


In [11]:

# https://stackoverflow.com/a/72511896/5755604
# Write the checkpoint straight to the GCS bucket through a gcsfs file object.
# NOTE(review): this rebinds the global `fs` that later cells use. Its project
# id differs from the setup cell's ("thesis"); confirm which one is intended.
# The credentials authenticated in the setup cell are passed explicitly so
# both handles use the same identity.
fs = gcsfs.GCSFileSystem(project="flowing-mantis-239216", token=credentials)

checkpoint = {
    "epoch": EPOCH,
    "model_state_dict": net.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "loss": LOSS,
}
with fs.open(
    "gs://thesis-bucket-option-trade-classification/models/model.pt", "wb"
) as f:
    torch.save(checkpoint, f)


In [12]:
from catboost import CatBoostClassifier
import os 

# https://catboost.ai/en/docs/concepts/python-reference_catboost_save_model
import tempfile

clf = CatBoostClassifier(iterations=5).fit([[1,2],[3,4]],[5,6])
# save_model writes to a local path, so stage the file in a private temporary
# directory and upload it from there. The directory is removed on exit even if
# the upload raises, and no "temp.cbm" in the working directory is overwritten.
with tempfile.TemporaryDirectory() as tmp_dir:
    local_path = os.path.join(tmp_dir, "temp.cbm")
    clf.save_model(local_path, format="cbm")
    fs.put_file(local_path, "gs://thesis-bucket-option-trade-classification/models/model.cbm")

Learning rate set to 0.093437
0:	learn: 0.6859798	total: 53.1ms	remaining: 213ms
1:	learn: 0.6789196	total: 53.2ms	remaining: 79.8ms
2:	learn: 0.6719564	total: 53.7ms	remaining: 35.8ms
3:	learn: 0.6650971	total: 53.7ms	remaining: 13.4ms
4:	learn: 0.6583346	total: 53.8ms	remaining: 0us


In [13]:
from catboost import CatBoostClassifier
import os 

clf = CatBoostClassifier(iterations=5).fit([[1,2],[3,4]],[5,6])

# Stream the serialized model bytes directly into the bucket, with no local
# temp file (https://stackoverflow.com/a/74067334/5755604).
# NOTE(review): `_serialize_model` is a private CatBoost method and may change
# between releases.
with fs.open("gs://thesis-bucket-option-trade-classification/models/cat.cbm", "wb") as gcs_file:
    payload = clf._serialize_model()
    gcs_file.write(payload)


Learning rate set to 0.093437
0:	learn: 0.6859798	total: 548us	remaining: 2.19ms
1:	learn: 0.6789196	total: 956us	remaining: 1.44ms
2:	learn: 0.6719564	total: 1.34ms	remaining: 892us
3:	learn: 0.6650971	total: 1.42ms	remaining: 356us
4:	learn: 0.6583346	total: 2.34ms	remaining: 0us


In [21]:
import time
import datetime
import copy
import numpy as np
from dataclasses import dataclass, field
from typing import List, Any
import warnings


class Callback:
    """
    Base class for training callbacks.

    Every hook here is a no-op; subclasses override only the hooks they need.
    """

    def __init__(self):
        """Create a callback; the base class keeps no state of its own."""

    def set_params(self, params):
        """Remember `params` on the instance as `self.params`."""
        self.params = params

    def on_epoch_end(self, epoch, epochs, train_loss, valid_loss):
        """Hook invoked after each epoch; does nothing by default."""

    def on_train_end(self, model: Any):
        """Hook invoked once training has finished; does nothing by default."""


@dataclass
class CallbackContainer:
    """
    Fan-out wrapper: forwards every hook call to each registered callback,
    in registration order.
    """

    callbacks: List[Callback] = field(default_factory=list)

    def _broadcast(self, hook, *args):
        # Invoke the method named `hook` on every callback, in list order.
        for cb in self.callbacks:
            getattr(cb, hook)(*args)

    def append(self, callback):
        """Register one more callback at the end of the list."""
        self.callbacks.append(callback)

    def set_params(self, params):
        """Hand `params` to every callback."""
        self._broadcast("set_params", params)

    def on_epoch_end(self, epoch, epochs, train_loss, valid_loss):
        """Forward the end-of-epoch notification to every callback."""
        self._broadcast("on_epoch_end", epoch, epochs, train_loss, valid_loss)

    def on_train_end(self, model):
        """Forward the end-of-training notification to every callback."""
        self._broadcast("on_train_end", model)


In [15]:
class PrinterCallback(Callback):
    """Callback that prints a one-line notice whenever an epoch ends."""

    def on_epoch_end(self, epoch, epochs, train_loss, valid_loss):
        tag = type(self).__name__
        print(f"[{tag}]: End of Epoch. Epoch: {epoch}")

In [16]:
import wandb
from typing import Any
from typing import Callable
from typing import Dict
from typing import Optional
from typing import Sequence
from typing import Union

In [17]:
import wandb
# Authenticate with Weights & Biases. In the captured run this prompted for an
# API key and appended it to /root/.netrc; the cell's displayed value is True.
wandb.login()

ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit: 

··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [27]:
class SaveCallback(Callback):
    """Callback that logs a model artifact to Weights & Biases after training.

    The artifact does not upload the model. It records a *reference* to an
    object that must already exist at ``artifact_uri``.

    Args:
        wandb_kwargs: Keyword arguments for ``wandb.init``. Only used when no
            W&B run is active yet.
        artifact_name: Name of the logged W&B artifact.
        artifact_uri: URI that the artifact references.
        reference_name: Name of the reference entry inside the artifact.

    NOTE(review): the defaults ("blabla", "bla") look like placeholders. They
    are kept so existing callers log the same artifact as before.
    """

    def __init__(
        self,
        wandb_kwargs: Optional[Dict[str, Any]] = None,
        artifact_name: str = "blabla",
        artifact_uri: str = "gs://thesis-bucket-option-trade-classification/models/cat.cbm",
        reference_name: str = "bla",
    ):
        super().__init__()
        self._wandb_kwargs = wandb_kwargs or {}
        self._artifact_name = artifact_name
        self._artifact_uri = artifact_uri
        self._reference_name = reference_name

        # Reuse the active wandb run if there is one; otherwise start a new run.
        self._run = wandb.run
        if not self._run:
            self._run = self._initialize_run()

    def _initialize_run(self) -> "wandb.sdk.wandb_run.Run":
        """Initializes a Weights & Biases run from ``self._wandb_kwargs``.

        Raises:
            RuntimeError: if ``wandb.init`` does not return a ``Run``.
        """
        run = wandb.init(**self._wandb_kwargs)
        if not isinstance(run, wandb.sdk.wandb_run.Run):
            raise RuntimeError(
                "Cannot create a Run. "
                "Expected wandb.sdk.wandb_run.Run as a return. "
                f"Got: {type(run)}."
            )
        return run

    def on_train_end(self, model):
        """Print the model and log an artifact that references ``artifact_uri``.

        NOTE(review): ``model`` is only printed. The artifact always points at
        ``artifact_uri``, so the model must already have been uploaded there.
        """
        print(f'[{type(self).__name__}]: Save Model. {model}')

        model_artifact = wandb.Artifact(name=self._artifact_name, type="model")
        model_artifact.add_reference(self._artifact_uri, name=self._reference_name)
        self._run.log_artifact(model_artifact)

In [28]:
# https://towardsdatascience.com/5-minutes-data-science-design-patterns-i-callback-b5c0738be277
def train_with_callback(callback=None):
    """Toy training loop that exercises the callback hooks.

    Fits a tiny CatBoost model. It then notifies ``callback`` at the end of
    each of three "epochs" (with dummy losses 5 and 6) and once more at the
    end of training.

    Args:
        callback: Object exposing ``on_epoch_end(epoch, epochs, train_loss,
            valid_loss)`` and ``on_train_end(model)``, e.g. a
            ``CallbackContainer``. ``None`` disables notifications.

    Returns:
        The dummy loss value 20.
    """
    n_epochs = 3
    loss = 20
    clf = CatBoostClassifier(iterations=5).fit([[1,2],[3,4]],[5,6])
    # Notify the callback handed in by the caller, not a module-level global.
    if callback is not None:
        for epoch in range(n_epochs):
            callback.on_epoch_end(epoch, n_epochs, 5, 6)
        callback.on_train_end(clf)
    return loss

# Define the W&B settings once. Start the run with them, and give them to
# SaveCallback too, which reuses the active run if one exists.
wandb_kwargs = {"project": "thesis", "entity": "fbv"}
run = wandb.init(**wandb_kwargs)

callbacks = CallbackContainer([PrinterCallback(), SaveCallback(wandb_kwargs=wandb_kwargs)])

train_with_callback(callback=callbacks)

Learning rate set to 0.093437
0:	learn: 0.6859798	total: 179us	remaining: 717us
1:	learn: 0.6789196	total: 264us	remaining: 396us
2:	learn: 0.6719564	total: 333us	remaining: 222us
3:	learn: 0.6650971	total: 2.06ms	remaining: 514us
4:	learn: 0.6583346	total: 2.15ms	remaining: 0us
[PrinterCallback]: End of Epoch. Epoch: 0
[PrinterCallback]: End of Epoch. Epoch: 1
[PrinterCallback]: End of Epoch. Epoch: 2
[SaveCallback]: Save Model. <catboost.core.CatBoostClassifier object at 0x7febf1055250>


20