In [None]:
import utils
from utils import *

In [None]:
# params
lr = 0.02  # falsy (None/0) -> the training cell runs Lightning's LR finder instead
batch_size = 65536  # falsy (None/0) -> the training cell runs Lightning's batch-size finder instead
# batch_size = None

# Extra kwargs splatted into L.Trainer by the training cell.
# NOTE(review): max_epochs=None does not make the run quick -- Lightning then falls back to its
# default epoch cap. Uncomment the limit_* fractions below for a genuinely short run.
quick_run = {
    'max_epochs': None,
    # 'limit_train_batches': 0.1,
    # 'limit_val_batches': 0.1,
    # 'limit_test_batches': 0.1,
}

# Alternative Trainer presets; the training cell only splats `quick_run`, swap one in when needed.
fast_dev_run_kwargs = {'fast_dev_run': True, 'enable_checkpointing': False}
# NOTE(review): Lightning's overfit_batches takes an int (number of batches) or a float (fraction);
# the bool True is ambiguous -- prefer an explicit 1 or 1.0.
overfit_batches_kwargs = {'overfit_batches': True, 'enable_checkpointing': False}
large_model = {'precision': "16-mixed"}
grad_accum = {'accumulate_grad_batches': 7}
# ckpt_path is an argument of trainer.fit(), not of the L.Trainer constructor.
resume_training = {'ckpt_path': 'path/to/ckpt'}

arcitecture_name = 'dnn'  # linear_regression | dnn
run_name = arcitecture_name
experiemnt_name = "mnist-classifier"
# Derive from utils' untouched base path (not from the current value of `dir_artifacts`) so that
# re-running this cell does not nest folders like .../dnn/dnn, and create the folder because
# matplotlib's savefig() does not create missing parent directories.
dir_artifacts = utils.dir_artifacts / arcitecture_name
dir_artifacts.mkdir(parents=True, exist_ok=True)

In [None]:
class LinearRegression(nn.Module):
    """A single affine layer mapping ``n_in`` input features to ``n_out`` outputs.

    In this notebook it serves as a linear classifier: ``forward`` returns the raw
    (un-normalised) scores, and the loss function applies the softmax.

    Args:
        n_in: number of input features per sample.
        n_out: number of output scores per sample.
    """

    def __init__(self, n_in, n_out):
        super().__init__()
        # The attribute name `fc` determines the state_dict keys (fc.weight / fc.bias).
        self.fc = nn.Linear(in_features=n_in, out_features=n_out)

    def forward(self, x):
        """Return ``x @ W.T + b`` for the wrapped linear layer."""
        scores = self.fc(x)
        return scores


class DNN(nn.Module):
    """Two-hidden-layer MLP: ``(Linear -> BatchNorm1d -> ReLU -> Dropout) x 2 -> Linear``.

    Args:
        n_in: input features per sample.
        h1: width of the first hidden layer.
        h2: width of the second hidden layer.
        n_out: number of outputs (raw scores, no final activation).
        dropout: drop probability, shared by both hidden layers.
    """

    def __init__(self, n_in: int, h1: int, h2: int, n_out: int, dropout: float):
        super().__init__()
        # Attribute names (and creation order) define the state_dict keys and init order; keep them.
        self.fc1 = nn.Linear(n_in, h1)
        self.bn1 = nn.BatchNorm1d(h1)
        self.fc2 = nn.Linear(h1, h2)
        self.bn2 = nn.BatchNorm1d(h2)
        self.fc3 = nn.Linear(h2, n_out)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Run both hidden stages, then project to the output scores."""
        hidden = x
        for linear, norm in ((self.fc1, self.bn1), (self.fc2, self.bn2)):
            hidden = self.dropout(F.relu(norm(linear(hidden))))
        return self.fc3(hidden)


class MNISTClassifier(L.LightningModule):
    """LightningModule that trains a wrapped network as a 10-class MNIST classifier.

    Inputs of every stage (train/valid/test/predict) are flattened to
    ``(batch, features)`` before reaching ``model``. The loss is cross-entropy on
    the model's raw scores, and a multiclass accuracy (10 classes) is logged with it.

    Args:
        model: network mapping flattened images (784 features, see
            ``example_input_array``) to 10 class scores.
        lr: Adam learning rate. Lightning's LR finder reads/updates an ``lr`` attribute.
    """

    def __init__(self, model, lr: float = 2e-2):
        super().__init__()
        self.model = model
        self.loss_func = F.cross_entropy
        self.lr = lr
        # Used by Lightning for model summaries / graph logging.
        self.example_input_array = torch.randn(5, 784)
        # NOTE(review): one Accuracy instance is shared by all stages; torchmetrics recommends
        # one per stage. Only the per-batch value returned by the call is logged below, but the
        # metric's accumulated internal state mixes stages.
        self.accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=10)
        # NOTE(review): this also stores the `model` nn.Module in hparams (Lightning warns about it);
        # kept as-is so existing load_from_checkpoint calls without arguments keep working.
        self.save_hyperparameters()

    def forward(self, x):
        """Return the wrapped model's raw class scores for ``x``."""
        return self.model(x)

    @staticmethod
    def _flatten(x):
        """Collapse all non-batch dimensions so inputs reach the model as ``(batch, features)``."""
        # reshape (unlike view) also accepts non-contiguous tensors; result is identical otherwise.
        return x.reshape(x.size(0), -1)

    def _step(self, batch, idx, set_name: str):
        """Shared train/valid/test logic.

        Flattens the inputs, computes the loss and accuracy, and logs them as
        ``{set_name}_loss`` / ``{set_name}_acc`` on the progress bar.

        Returns:
            The loss tensor (used by Lightning for backprop in training).
        """
        x, y = batch
        pred = self(self._flatten(x))
        loss = self.loss_func(pred, y)
        self.log(f'{set_name}_loss', loss, prog_bar=True)
        acc = self.accuracy(pred, y)
        self.log(f'{set_name}_acc', acc, prog_bar=True)
        return loss

    def training_step(self, batch, idx):
        return self._step(batch, idx, 'train')

    def validation_step(self, batch, idx):
        return self._step(batch, idx, 'valid')

    def test_step(self, batch, idx):
        return self._step(batch, idx, 'test')

    def predict_step(self, batch, idx, dataloader_idx=0):
        """Return raw class scores for the inputs of a predict batch.

        Flattens like ``_step`` so predict accepts the same input layout as the other stages.
        """
        # assumes predict batches are (x, y) tuples like the other stages -- TODO confirm
        return self(self._flatten(batch[0]))

    def configure_optimizers(self):
        """Adam over all parameters with the (possibly tuned) learning rate ``self.lr``."""
        return optim.Adam(params=self.parameters(), lr=self.lr)
        



def plot_samples(preds, x, y, losses, idxs, rows=5, cols=6, path: Path | str | None = None,
                 title: str = "Predictions with highest loss"):
    """Draw a ``rows x cols`` grid of digits annotated with label, prediction and loss.

    Args:
        preds: predicted class per sample (tensor).
        x: images; each ``x[i]`` is reshaped to 28x28 for display.
        y: true labels (tensor).
        losses: per-sample loss values (tensor).
        idxs: sample indices to show, in row-major order; only the first
            ``rows * cols`` are used. If there are fewer, the remaining cells stay blank.
        rows, cols: grid shape.
        path: if given, the figure is saved there and then closed.
        title: figure super-title (defaults to the previous hard-coded title).

    Returns:
        The matplotlib Figure.
    """
    # .cpu() so tensors that live on a GPU can also be converted to numpy.
    preds, x, y, losses = (t.detach().cpu().numpy() for t in (preds, x, y, losses))
    # squeeze=False keeps `axes` 2-D even when rows or cols is 1.
    fig, axes = plt.subplots(rows, cols, figsize=(20, 16), squeeze=False)
    grid = axes.ravel()  # row-major, same order as axes[i, j] with i * cols + j
    shown = [int(i) for i in idxs[: rows * cols]]
    for ax, idx in zip(grid, shown):
        ax.imshow(x[idx].reshape(28, 28), cmap='gray')
        ax.set_title(f'actual={y[idx]:.0f},pred={preds[idx]:.0f},loss={losses[idx]:.2f}')
        ax.axis('off')
    for ax in grid[len(shown):]:  # blank out cells left over when there are too few samples
        ax.axis('off')
    fig.suptitle(title, fontsize=16)
    fig.tight_layout()
    fig.subplots_adjust(top=0.94)
    if path:
        fig.savefig(Path(path).as_posix())
        plt.close(fig)  # close this figure specifically, not whatever is "current"
    return fig


def plot_top_losses(preds, x, y, losses, **kwargs):
    """Plot the samples with the largest loss, highest first, via ``plot_samples``.

    Takes up to 100 top-loss indices. The count is clamped to the number of samples
    so ``topk`` does not raise on small datasets. Extra ``kwargs`` (rows, cols,
    path, ...) are forwarded to ``plot_samples``.
    """
    k = min(100, losses.numel())
    _, idxs = losses.topk(k)
    return plot_samples(preds=preds, x=x, y=y, losses=losses, idxs=idxs, **kwargs)
    

def plot_random_samples(preds, x, y, losses, **kwargs):
    """Plot randomly chosen samples (without replacement) via ``plot_samples``.

    Draws up to 100 distinct indices. The count is clamped to ``len(x)`` so small
    datasets do not make ``random.sample`` raise. It uses the global ``random``
    state, so call ``random.seed`` first for a reproducible selection. Extra
    ``kwargs`` are forwarded to ``plot_samples``.
    """
    # NOTE(review): plot_samples' default suptitle reads "Predictions with highest loss",
    # which is misleading for a random selection.
    idxs = random.sample(range(len(x)), k=min(100, len(x)))
    return plot_samples(preds=preds, x=x, y=y, losses=losses, idxs=idxs, **kwargs)

In [None]:
# Train (optionally tuning lr / batch size), evaluate, plot and log everything to one MLflow run.
# Ensure the artifact folder exists before savefig()/log_artifacts() write to it.
dir_artifacts.mkdir(parents=True, exist_ok=True)
# as_uri() builds a valid file URI on every OS ('file://' + path breaks on Windows drive
# letters and on relative paths); resolve() makes the path absolute as as_uri() requires.
mlflow.set_tracking_uri(dir_mlruns.resolve().as_uri())
mlflow.set_experiment(experiment_name=experiemnt_name)
mlflow.pytorch.autolog()
try:
    with mlflow.start_run(run_name=run_name) as run:
        data = MNISTDataModule(dir_mnist.as_posix())
        arcitectures = {
            'linear_regression': LinearRegression(784, 10),
            'dnn': DNN(784, 1024, 256, 10, dropout=0.5),
        }
        arcitecture = arcitectures[arcitecture_name]
        model = MNISTClassifier(model=arcitecture)
        callbacks = [
            ModelCheckpoint(every_n_epochs=2),
            EarlyStopping(monitor="valid_loss"),
            StochasticWeightAveraging(swa_lrs=1e-2),
        ]
        # NOTE(review): mlflow.pytorch.autolog() also hooks Lightning; combined with this logger some
        # metrics may be logged twice -- confirm.
        mlf_logger = MLFlowLogger(
            experiment_name=experiemnt_name,
            run_id=run.info.run_id,  # attach to the run opened above instead of starting a new one
            log_model=True,
            tracking_uri=uri_mlruns,  # NOTE(review): assumed to point at the same store as dir_mlruns
        )
        trainer = L.Trainer(callbacks=callbacks, logger=mlf_logger, **quick_run)
        tuner = Tuner(trainer)
        if lr:
            tuned_lr = lr
        else:
            lr_finder = tuner.lr_find(model, datamodule=data)
            fig = lr_finder.plot(suggest=True)
            fig.savefig((dir_artifacts / 'lr_finder.png').as_posix())
            plt.close(fig)
            tuned_lr = lr_finder.suggestion()
            if tuned_lr is None:  # the finder could not suggest a value; keep the module's lr
                tuned_lr = model.lr
            print(f'tuned_lr={tuned_lr:.3g}')

        if batch_size:
            tuned_batch_size = batch_size
        else:
            tuned_batch_size = tuner.scale_batch_size(model, datamodule=data, mode='power')
            print(f'tuned_batch_size={tuned_batch_size:,}')

        model = MNISTClassifier(model=arcitecture, lr=tuned_lr)
        data = MNISTDataModule(dir_mnist.as_posix(), batch_size=tuned_batch_size)
        trainer.fit(model, datamodule=data)
        trainer.save_checkpoint((dir_artifacts / "final_model.ckpt").as_posix())

        # calculating loss
        # Presumably also runs the datamodule's predict setup that populates ds_predict -- TODO confirm.
        trainer.predict(model, datamodule=data)
        ds = data.ds_predict
        # The model may still be in training mode here: train-mode BatchNorm1d rejects batches of
        # one sample and Dropout would randomise the losses. Evaluate in eval mode, without autograd.
        model.eval()
        with torch.no_grad():
            # Per-sample losses, computed in batches (in dataset order) instead of one image at a time.
            losses = torch.cat([
                F.cross_entropy(model(xb), yb, reduction='none')
                for xb, yb in tqdm(torch.utils.data.DataLoader(ds, batch_size=tuned_batch_size))
            ])
            preds = model(ds.x).argmax(1)
        # plotting
        fig = plot_top_losses(preds, ds.x, ds.y, losses, path=dir_artifacts / 'predictions_with_highest_loss.png')
        fig = plot_random_samples(preds, ds.x, ds.y, losses, path=dir_artifacts / 'predictions_sample.png')

        trainer.test(model, data)
        mlflow.log_artifacts(dir_artifacts.as_posix())
finally:
    # Always switch autolog off again, even if training raised, so later cells are unaffected.
    mlflow.pytorch.autolog(disable=True)
launch_mlflow_ui(uri=uri_mlruns, run=run)