In [13]:
import numpy as np
from typing import Optional
from functools import partial
import torch
from torch import nn
from torch.nn import functional as F
import torch.nn.init as init
from torch.utils.data import Dataset, DataLoader



class SimCSE(nn.Conv2d):
  """Sharpened cosine-similarity (SCS) 2-D layer.

  Subclasses ``nn.Conv2d`` to reuse its channel/stride/padding bookkeeping,
  but replaces the plain convolution with a sharpened cosine similarity.
  For every output position it:

    1. computes the dot product between an l2-normalized kernel and the
       input patch, divided by the patch's l2-norm plus a learned offset
       ``q`` (i.e. a damped cosine similarity);
    2. raises the magnitude of that similarity to a learned power ``p``,
       keeping its sign;
    3. optionally multiplies the result by a learned scalar ``alpha``.

  The layer has no bias.

  Args:
    in_channels: number of input channels.
    out_channels: number of output channels (kernels); must be a multiple
      of ``groups``.
    kernel_size: side length of the (square) kernel, as an int.
    padding: zero padding passed to ``F.conv2d``.
    stride: convolution stride.
    groups: must be 1 or ``in_channels``.
    shared_weights: when ``groups > 1``, learn only
      ``out_channels // groups`` kernels (and ``p`` values) and tile them
      across all groups. Forced to False when ``groups == 1``.
    log_p_init, log_p_scale: the stored parameter is
      ``log_p_init * log_p_scale`` and the effective exponent is
      ``exp(self.p / log_p_scale)``, i.e. ``exp(log_p_init)`` at init.
    log_q_init, log_q_scale: the stored parameter is
      ``log_q_init * log_q_scale`` and the effective offset is
      ``exp(-self.q / log_q_scale)``, i.e. ``exp(-log_q_init)`` at init.
    alpha: initial value of the learned output scale, or None to disable it.
    alpha_autoinit: if True and ``alpha`` is not None, rescale ``alpha``
      with :meth:`LSUV_like_init` at construction time.
    eps: small constant added for numerical stability.
  """

  def __init__(
    self,
    in_channels: int,
    out_channels: int,
    kernel_size: int=3,
    padding: int=0,
    stride: int=1,
    groups: int=1,
    shared_weights: bool = False,
    log_p_init: float=.7,
    log_q_init: float=1.,
    log_p_scale: float=5.,
    log_q_scale: float=.3,
    alpha: Optional[float]=10,
    alpha_autoinit: bool=False,
    eps: float=1e-6,
  ):
    assert groups == 1 or groups == in_channels, " ".join([
        "'groups' needs to be 1 or 'in_channels' ",
        f"({in_channels})."])
    assert out_channels % groups == 0, " ".join([
        "The number of",
        "output channels needs to be a multiple of the number",
        "of groups.\nHere there are",
        f"{out_channels} output channels and {groups}",
        "groups."])

    self.in_channels = in_channels
    self.out_channels = out_channels
    self.stride = stride
    self.groups = groups
    self.shared_weights = shared_weights

    # Weight sharing only makes sense when there are several groups.
    if self.groups == 1:
      self.shared_weights = False

    super(SimCSE, self).__init__(
        self.in_channels,
        self.out_channels,
        kernel_size,
        bias=False,
        padding=padding,
        stride=stride,
        groups=self.groups)

    # Overwrite self.kernel_size created in the 'super' above.
    # We want an int, assuming a square kernel, rather than a tuple.
    self.kernel_size = kernel_size

    # Scaling weights in this way generates kernels that have
    # an l2-norm of about 1. Since they get normalized to 1 during
    # the forward pass anyway, this prevents any numerical
    # or gradient weirdness that might result from large amounts of
    # rescaling.
    self.channels_per_kernel = self.in_channels // self.groups
    weights_per_kernel = self.channels_per_kernel * self.kernel_size ** 2
    if self.shared_weights:
      self.n_kernels = self.out_channels // self.groups
    else:
      self.n_kernels = self.out_channels
    initialization_scale = (3 / weights_per_kernel) ** .5
    # NOTE: drawn from numpy's global RNG, so torch.manual_seed alone does
    # not make this initialization reproducible.
    scaled_weight = np.random.uniform(
        low=-initialization_scale,
        high=initialization_scale,
        size=(
            self.n_kernels,
            self.channels_per_kernel,
            self.kernel_size,
            self.kernel_size)
    )
    self.weight = torch.nn.Parameter(torch.Tensor(scaled_weight))

    # p (one per kernel) and q (shared) are stored pre-multiplied by their
    # scales and mapped through exp() in forward() to keep them positive.
    self.log_p_scale = log_p_scale
    self.log_q_scale = log_q_scale
    self.p = torch.nn.Parameter(torch.full(
        (1, self.n_kernels, 1, 1),
        float(log_p_init * self.log_p_scale)))
    self.q = torch.nn.Parameter(torch.full(
        (1, 1, 1, 1), float(log_q_init * self.log_q_scale)))
    self.eps = eps

    if alpha is not None:
      self.alpha = torch.nn.Parameter(torch.full(
          (1, 1, 1, 1), float(alpha)))
    else:
      self.alpha = None
    if alpha_autoinit and (alpha is not None):
      self.LSUV_like_init()

  def LSUV_like_init(self):
    """Rescale ``alpha`` so outputs on random input have roughly unit std.

    Runs a no-grad forward pass on a batch of 32 uniform-random inputs whose
    spatial size equals the kernel size, takes the per-channel output std
    (plus ``eps``), averages it over channels, and divides ``alpha`` by that
    value. Requires ``self.alpha`` to be a Parameter (not None).
    """
    BS, CH = 32, int(self.weight.shape[1]*self.groups)
    H, W = self.weight.shape[2], self.weight.shape[3]
    device = self.weight.device
    inp = torch.rand(BS, CH, H, W, device=device, dtype=self.weight.dtype)
    with torch.no_grad():
        out = self.forward(inp)
        coef = (out.std(dim=(0, 2, 3)) + self.eps).mean()
        # Output is linear in alpha, so dividing by coef rescales the std.
        self.alpha.data *= 1.0 / coef.view_as(self.alpha)
    return

  def forward(self, inp: torch.Tensor) -> torch.Tensor:
    """Apply the sharpened cosine similarity to ``inp``.

    ``inp`` is passed to ``F.conv2d`` with channels on dim 1, so it is
    presumably (N, in_channels, H, W).
    """
    # Scale and transform the p and q parameters
    # to ensure that their magnitudes are appropriate
    # and their gradients are smooth
    # so that they will be learned well.
    p = torch.exp(self.p / self.log_p_scale)
    q = torch.exp(-self.q / self.log_q_scale)

    # If necessary, expand out the weight and p parameters.
    if self.shared_weights:
        weight = torch.tile(self.weight, (self.groups, 1, 1, 1))
        p = torch.tile(p, (1, self.groups, 1, 1))
    else:
        weight = self.weight

    return self.scs(inp, weight, p, q)

  def scs(self, inp, weight, p, q):
    """Core SCS computation given already-expanded ``weight``, ``p``, ``q``."""
    # Normalize the kernel weights.
    weight = weight / self.weight_norm(weight)

    # Dot product of the normalized kernels with the inputs, divided by the
    # (q-offset) input norm at each position: a damped cosine similarity.
    cos_sim = F.conv2d(
        inp,
        weight,
        stride=self.stride,
        padding=self.padding,
        groups=self.groups,
    ) / self.input_norm(inp, q)

    # Raise the result to the power p, keeping the sign of the original.
    out = cos_sim.sign() * (cos_sim.abs() + self.eps) ** p

    # Apply learned scale parameter
    if self.alpha is not None:
      out = self.alpha.view(1, -1, 1, 1) * out
    return out

  def weight_norm(self, weight):
    """Return the l2-norm of each kernel, shaped (n_kernels, 1, 1, 1)."""
    return weight.square().sum(dim=(1, 2, 3), keepdim=True).sqrt()

  def input_norm(self, inp, q):
    """Return the l2-norm of each input patch plus ``q``, one per output channel.

    The squared inputs are summed over each kernel window by convolving them
    with an all-ones kernel (one per group); the per-group result is then
    repeated so each output channel of its group gets the same norm.
    """
    # Build the all-ones kernel on the same device and with the same dtype
    # as the input; otherwise F.conv2d fails for CUDA or non-float32 inputs.
    ones_kernel = torch.ones(
        (self.groups,
         self.channels_per_kernel,
         self.kernel_size,
         self.kernel_size),
        device=inp.device,
        dtype=inp.dtype)
    xnorm = F.conv2d(
        inp.square(),
        ones_kernel,
        stride=self.stride,
        padding=self.padding,
        groups=self.groups)

    # Add in the q parameter.
    xnorm = (xnorm + self.eps).sqrt() + q
    outputs_per_group = self.out_channels // self.groups
    return torch.repeat_interleave(xnorm, outputs_per_group, dim=1)


class AbsPool(nn.Module):
  """Pool by signed magnitude: keep whichever extreme has the larger |value|.

  Wraps an arbitrary pooling module class. The wrapped layer is applied to
  ``x`` (largest value) and to ``-x`` (negated smallest value); for each
  output location the result with the larger magnitude is returned with its
  original sign. Ties go to the positive value.

  Args:
    pooling_module: pooling class to instantiate (e.g. ``nn.MaxPool2d``).
    *args, **kwargs: forwarded to ``pooling_module``'s constructor.
  """

  def __init__(self, pooling_module=None, *args, **kwargs):
    super().__init__()
    self.pooling_layer = pooling_module(*args, **kwargs)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    largest = self.pooling_layer(x)
    smallest = -self.pooling_layer(-x)
    # Prefer the positive extreme unless the negative one is strictly larger
    # in magnitude.
    keep_largest = largest >= -smallest
    return torch.where(keep_largest, largest, smallest)


# MaxAbsPool2d(**kwargs) == AbsPool(nn.MaxPool2d, **kwargs): max-|x| 2-D pooling.
MaxAbsPool2d = partial(AbsPool, nn.MaxPool2d)

class Network(nn.Module):
    """Three-stage sharpened-cosine-similarity classifier.

    Architecture: (SimCSE -> MaxAbsPool2d) x 3, flatten, then a Linear layer
    producing one logit per entry of ``classes``.

    NOTE(review): layer sizes are read from notebook-level globals
    ``n_input_channels``, ``n_units_1``, ``n_units_2``, ``n_units_3`` and
    ``classes``. The first four are not defined anywhere in the visible
    notebook, so constructing this class relies on kernel state — TODO
    confirm where they are set.
    """

    def __init__(self):
        super().__init__()

        self.scs1 = SimCSE(
            in_channels=n_input_channels,
            out_channels=n_units_1,
            kernel_size=5,
            padding=0)
        self.pool1 = MaxAbsPool2d(kernel_size=2, stride=2, ceil_mode=True)

        self.scs2 = SimCSE(
            in_channels=n_units_1,
            out_channels=n_units_2,
            kernel_size=5,
            padding=1)
        self.pool2 = MaxAbsPool2d(kernel_size=2, stride=2, ceil_mode=True)

        self.scs3 = SimCSE(
            in_channels=n_units_2,
            out_channels=n_units_3,
            kernel_size=5,
            padding=1)
        self.pool3 = MaxAbsPool2d(kernel_size=4, stride=4, ceil_mode=True)

        # in_features must equal n_units_3 * H * W of the pool3 output. For a
        # 256x256 input the spatial size goes 252 -> 126 -> 124 -> 62 -> 60
        # -> 15, so 3600 corresponds to n_units_3 == 16 (assumption — TODO
        # confirm; this must change if the input size or widths change).
        self.out = nn.Linear(in_features=3600, out_features=len(classes))

    def n_params(self):
        """Return the total number of trainable parameters in the network.

        Counts every registered parameter with ``requires_grad=True``: each
        SimCSE layer's ``weight``, ``p``, ``q`` and (when enabled) ``alpha``,
        plus the output layer's weight and bias.
        """
        return sum(
            param.numel() for param in self.parameters() if param.requires_grad)

    def forward(self, t):
        """Map a batch of images to class logits of shape (N, len(classes))."""
        t = self.scs1(t)
        t = self.pool1(t)

        t = self.scs2(t)
        t = self.pool2(t)

        t = self.scs3(t)
        t = self.pool3(t)

        # Flatten everything except the batch dimension.
        t = t.view(t.size(0), -1)
        t = self.out(t)

        return t

In [28]:
class Net(nn.Module):
  """Small LeNet-style CNN: two conv/ReLU/max-pool stages, then three
  fully connected layers.

  ``fc1`` expects 16 * 61 * 61 = 59536 features, which is what a
  3 x 256 x 256 input yields (256 -> 252 -> 126 -> 122 -> 61).
  NOTE(review): ``fc3`` emits 128 logits although the notebook's ``classes``
  list has 7 entries. The saved checkpoint has the same shape (see the model
  repr output), so the head is left as-is.
  """

  def __init__(self):
    super().__init__()
    # Attribute names (and their order) define the state_dict keys.
    self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
    self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
    self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
    self.fc1 = nn.Linear(in_features=16 * 61 * 61, out_features=128)
    self.fc2 = nn.Linear(in_features=128, out_features=128)
    self.fc3 = nn.Linear(in_features=128, out_features=128)

  def forward(self, x):
    # conv -> ReLU -> 2x2 max-pool, twice; the pooling module is shared.
    for conv in (self.conv1, self.conv2):
      x = self.pool(F.relu(conv(x)))
    x = x.flatten(start_dim=1)  # keep the batch dimension
    for hidden in (self.fc1, self.fc2):
      x = F.relu(hidden(x))
    return self.fc3(x)  # raw logits

In [57]:
# Load the full pickled model (architecture + weights) onto the CPU.
# torch.load un-pickles arbitrary Python objects, so only load trusted files;
# un-pickling looks up the saved object's class (`Net`, defined above) by name.
# weights_only=False is required for whole-module checkpoints on PyTorch >= 2.6,
# where torch.load defaults to weights_only=True.
# TODO: make this path configurable instead of an absolute local path.
MODEL_PATH = '/Users/andrewargeros/Documents/CDS-5950-Capstone/Models/convolution.pt'
net = torch.load(MODEL_PATH, map_location=torch.device('cpu'), weights_only=False)

In [50]:
# Display the loaded model's architecture (its nn.Module repr).
net

Net(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=59536, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=128, bias=True)
  (fc3): Linear(in_features=128, out_features=128, bias=True)
)

In [38]:
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Deterministic inference-time preprocessing (no random augmentation):
#   1. rescale so the shortest image side is 350 px,
#   2. take the central 256x256 crop (the size Net's fc1 is built for),
#   3. normalize per channel (these look like the common ImageNet RGB stats),
#   4. convert the HWC numpy image into a CHW torch tensor.
test_transforms = A.Compose(transforms=[
    A.SmallestMaxSize(max_size=350),
    A.CenterCrop(height=256, width=256),
    A.Normalize(mean=(0.485, 0.456, 0.406),
                std=(0.229, 0.224, 0.225)),
    ToTensorV2(),
])

In [61]:

import cv2  # used below but not imported anywhere else in this notebook

beer = '/Users/andrewargeros/Downloads/IMG_3448.JPG'

# Label names, indexed by the position of the corresponding logit.
classes = ['Dark Malty Beers',
           'Fruit Beer',
           'IPA',
           'Light Beers',
           'nan',
           'NOT APPLICABLE',
           'Stouts']

# cv2.imread returns None (instead of raising) when a file can't be read.
image = cv2.imread(beer)
if image is None:
    raise FileNotFoundError(f"Could not read image file: {beer}")
# NOTE(review): cv2 loads images in BGR order, while the Normalize statistics
# in test_transforms appear to be RGB-ordered. Confirm whether training
# converted with cv2.cvtColor(image, cv2.COLOR_BGR2RGB), and do the same here
# if so.
t = test_transforms(image=image)

b = torch.unsqueeze(t['image'], 0)  # add a batch dimension of 1
net.eval()
with torch.no_grad():  # inference only; don't build an autograd graph
    out = net(b)

# `net` emits more logits than there are labels (fc3 has 128 outputs), so
# keep only the first len(classes) columns. Otherwise an unused logit could
# rank in the top 5 and `classes[idx]` would raise IndexError.
class_logits = out[:, :len(classes)]
prob = torch.nn.functional.softmax(class_logits, dim=1)[0] * 100
_, indices = torch.sort(class_logits, descending=True)
[(classes[idx], prob[idx].item()) for idx in indices[0][:5]]

[('Dark Malty Beers', 30.145248413085938),
 ('Light Beers', 21.284114837646484),
 ('Stouts', 18.034154891967773),
 ('Fruit Beer', 15.659406661987305),
 ('IPA', 8.632295608520508)]

In [54]:
# Raw (pre-softmax) logits produced by `net` for the beer image above.
out

tensor([[  0.5031,   0.4296,  -0.0826,   0.9459, -12.7544,  -1.4035,  -0.1533,
         -12.5288, -12.7742, -12.7771, -13.3568, -11.8579, -14.3833, -12.9176,
         -12.5038, -12.2867, -12.4958, -13.1342, -13.4447, -13.3530, -13.4480,
         -12.6563, -12.3481, -13.5340, -12.4730, -13.4127, -12.8176, -13.2121,
         -13.2459, -13.6531, -12.9309, -12.9689, -12.7460, -14.0753, -13.5890,
         -12.7821, -12.7713, -13.1488, -13.5154, -13.0589, -13.3421, -12.0689,
         -13.7885, -13.1374, -13.2224, -12.5681, -12.8236, -13.3941, -12.6402,
         -13.3742, -13.4391, -13.9239, -13.5879, -12.5649, -12.2784, -12.0253,
         -12.2664, -12.4480, -13.2110, -11.9693, -12.2393, -13.8797, -13.1799,
         -11.3393, -13.8272, -13.1139, -13.5190, -12.8256, -13.8177, -14.0785,
         -12.5435, -13.6326, -12.6651, -12.1869, -13.6152, -13.3730, -13.4342,
         -13.4093, -12.8215, -12.8179, -13.4719, -13.5709, -13.7991, -13.6349,
         -13.2671, -12.8880, -13.2179, -13.9519, -13