In [3]:
import torch
from torch import nn
import pytorch_lightning as pl
import torchmetrics
import seaborn as sns
import matplotlib.pyplot as plt
import wandb
import numpy as np
import pdb


In [4]:
resnet50 = torch.hub.load(
            "pytorch/vision:v0.9.0",
            "resnet50",
    
        )

Using cache found in /home/juliu/.cache/torch/hub/pytorch_vision_v0.9.0


In [9]:
class ResNet_withMeta(pl.LightningModule):
    def __init__(self, args, num_classes=1000):
        super().__init__()

        self.img_backbone = torch.hub.load(
            "pytorch/vision:v0.9.0",
            "resnet50",
            weights="ResNet50_Weights.IMAGENET1K_V1",
        )
        self.img_backbone.fc = nn.Identity()

        self.meta_backbone = nn.Linear(16, 16)
        self.meta_backbonev2 = nn.Sequential(
            nn.Linear(16, 32),
            nn.ReLU(),
            nn.Linear(32, 32),
            nn.ReLU(),
            nn.Linear(32, 16),
            nn.ReLU()
        )
        self.relu = nn.ReLU()
        self.classifier = nn.Linear(2048 + 16, num_classes)
        self.args = args

        self.accuracy1 = torchmetrics.Accuracy(
            task="multiclass", num_classes=num_classes
        )
        self.accuracy3 = torchmetrics.Accuracy(
            task="multiclass", num_classes=num_classes, top_k=3
        )
        self.f1_score = torchmetrics.F1Score(
            task="multiclass", num_classes=num_classes, average="micro"
        )


    def forward(self, img, meta):
        features_img = self.img_backbone(img)  # [N, 2048]
        features_meta = self.meta_backbonev2(meta)  # [N, n_features]
        features = torch.cat((features_img, features_meta), 1)
        features = self.relu(features)

        return self.classifier(features)

    def training_step(self, batch, batch_idx):
        img, meta, y = batch
        y_hat = self(img, meta)
        loss = nn.CrossEntropyLoss()(y_hat, y)
        self.log("train_loss", loss)
        pdb.set_trace()
        return loss

    def validation_step(self, batch, batch_idx):
        img, meta, y = batch
        y_hat = self(img, meta)
        loss = nn.CrossEntropyLoss()(y_hat, y)
        self.log("val_loss", loss)
        self.log("val_acc1", self.accuracy1(y_hat, y))
        self.log("val_acc3", self.accuracy3(y_hat, y))
        self.log("val_f1", self.f1_score(y_hat, y))

    def on_validation_epoch_end(self):
        pass

    def configure_optimizers(self):
        # Choose optimizer from dict
        from itertools import chain

        dictOptimizer = {
            "adam": torch.optim.Adam(
                chain(
                    self.img_backbone.fc.parameters(),
                    self.classifier.parameters(),
                    self.meta_backbone.parameters(),
                ),
                lr=self.args.lr,
                weight_decay=self.args.weight_decay,
            ),
            "sgd": torch.optim.SGD(
                chain(
                    self.img_backbone.fc.parameters(),
                    self.classifier.parameters(),
                    self.meta_backbone.parameters(),
                ),
                lr=self.args.lr,
                weight_decay=self.args.weight_decay,
                momentum=self.args.momentum,
            ),
        }
        optimizer = dictOptimizer[self.args.optimizer]
        # optimizer = torch.optim.Adam(
        #    self.parameters(), lr=self.args.lr, weight_decay=self.args.weight_decay
        # )
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=self.args.epochs
        )  # StepLR -> cosine
        return [optimizer], [scheduler]


In [12]:
resnet = ResNet_withMeta(args=None, num_classes=999)

Using cache found in /home/juliu/.cache/torch/hub/pytorch_vision_v0.9.0


In [13]:
resnet

ResNet_withMeta(
  (img_backbone): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequentia

In [16]:
inp = torch.rand(1, 3, 224, 224)

In [29]:
meta = torch.tensor([[1, 0, 1, 0, 0,0,0,0,0, 0, 0, 0, 0,0,1,0]], dtype=torch.float)
meta.shape

torch.Size([1, 16])

In [32]:
out = resnet.forward(inp, meta)

In [37]:
softmax = nn.Softmax(dim=1)

In [38]:
torch.argmax(softmax(out))

tensor(495)