## Concise Logistic Regression for Image Classification

- Shows a concise implementation of logistic regression for binary image classification (a single sigmoid output trained with binary cross-entropy)
- Uses PyTorch

In [None]:
# imports
import torch
import torchvision
import torch.nn as nn
from torchvision import datasets, models, transforms
import os
import numpy as np
import matplotlib.pyplot as plt

%matplotlib inline

# Run on the first CUDA GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [None]:
# download the data
!wget https://download.pytorch.org/tutorial/hymenoptera_data.zip
!unzip hymenoptera_data.zip

In [None]:
# create data loaders

data_dir = "hymenoptera_data"  # root folder holding the 'train' and 'val' image-folder splits

# custom transform that reshapes image tensors (given (-1,) it flattens them)
class ReshapeTransform:
    """Callable transform that reshapes a tensor to a fixed target shape.

    Args:
        new_size: target shape passed to ``torch.reshape``; ``(-1,)`` flattens
            the input into a 1-D tensor.
    """

    def __init__(self, new_size):
        self.new_size = new_size

    def __call__(self, img):
        """Return ``img`` reshaped to ``self.new_size``."""
        return torch.reshape(img, self.new_size)

    def __repr__(self):
        # transforms.Compose prints each step's repr; without this the pipeline
        # shows an opaque "<ReshapeTransform object at 0x...>" entry.
        return f"{type(self).__name__}(new_size={self.new_size!r})"

# Per-split preprocessing. Both splits use the same steps: resize the shorter
# side to 224, center-crop to 224x224, convert to a float tensor (ToTensor also
# scales pixel values to [0, 1]), then flatten into a 1-D feature vector for the
# linear model. Each split gets its own, independent Compose instance.
data_transforms = {
    split: transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        ReshapeTransform((-1,)),
    ])
    for split in ('train', 'val')
}

# Load the corresponding folders: one ImageFolder per split, read from
# <data_dir>/<split> and preprocessed with that split's transform pipeline.
image_datasets = {
    split: datasets.ImageFolder(root=os.path.join(data_dir, split),
                                transform=data_transforms[split])
    for split in ('train', 'val')
}

# Full-batch loaders: batch_size equals the size of each split, so a single
# pass over a loader yields the whole split at once (no minibatches).
# NOTE: despite their names, both objects are DataLoaders, not datasets.
train_dataset, test_dataset = (
    torch.utils.data.DataLoader(
        image_datasets[split],
        batch_size=len(image_datasets[split]),
        shuffle=True,
    )
    for split in ('train', 'val')
)

In [None]:
# logistic-regression model: one linear unit followed by a sigmoid
class LR(nn.Module):
    """Binary logistic regression, ``sigmoid(x @ w.T + b)``.

    Args:
        dim: number of input features (size of the input's last dimension).

    Weight and bias start at zero, so every initial output is exactly 0.5.
    """

    def __init__(self, dim):
        super().__init__()
        self.linear = nn.Linear(dim, 1)
        for param in (self.linear.weight, self.linear.bias):
            nn.init.zeros_(param)

    def forward(self, x):
        """Map inputs with last dimension ``dim`` to probabilities with last dimension 1."""
        return torch.sigmoid(self.linear(x))

In [None]:
# accuracy helper
def predict(yhat, y):
    """Return the classification accuracy, in percent, of probabilities against labels.

    Args:
        yhat: predicted probabilities, one value per sample (e.g. the
            ``(batch, 1)`` output of ``LR``); may live on any device and may
            require grad.
        y: ground-truth 0/1 labels, one per sample.

    Returns:
        0-d float tensor (on the CPU) holding the percentage of samples whose
        thresholded prediction (1 if ``yhat > 0.5`` else 0) equals the label.
    """
    # reshape(-1) rather than squeeze(): squeeze() turns a batch of one into a
    # 0-d tensor. Thresholding is vectorized instead of a per-element loop.
    preds = (yhat.detach().reshape(-1) > 0.5).float().cpu()
    # bring labels to the same dtype/device so the subtraction is well-defined
    labels = y.detach().reshape(-1).float().cpu()
    return 100 - torch.mean(torch.abs(preds - labels)) * 100

In [None]:
# Model configuration.
# dim: number of input features, i.e. the length of the first training image
# after the transforms have flattened it.
dim = train_dataset.dataset[0][0].size(0)

lrmodel = LR(dim).to(device)                                       # model on the chosen device
criterion = nn.BCELoss()                                           # binary cross-entropy on sigmoid outputs
optimizer = torch.optim.SGD(params=lrmodel.parameters(), lr=1e-4)  # plain SGD, step size 1e-4

In [None]:
# training the model (full-batch gradient descent)
costs = []

# Each loader's batch_size is the size of its split, so one next(iter(...))
# yields the whole split. Load it once here instead of re-reading and
# re-decoding every image from disk on each iteration; the shuffle order of a
# full batch does not change the mean loss or its gradient.
x, y = next(iter(train_dataset))
test_x, test_y = next(iter(test_dataset))

# BCELoss needs float targets on the same device as the model output.
train_inputs = x.to(device)
train_targets = y.float().to(device)
test_inputs = test_x.to(device)

for ITER in range(200):
    lrmodel.train()

    # forward (call the module rather than .forward so registered hooks run)
    yhat = lrmodel(train_inputs)
    # view(-1) instead of squeeze(): a single-sample batch stays 1-D like the targets
    cost = criterion(yhat.view(-1), train_targets)
    train_pred = predict(yhat, y)

    # backward
    optimizer.zero_grad()
    cost.backward()
    optimizer.step()

    # evaluate
    lrmodel.eval()
    with torch.no_grad():
        yhat_test = lrmodel(test_inputs)
        test_pred = predict(yhat_test, test_y)

    if ITER % 10 == 0:
        # store a plain float: keeping the tensor would hold on to this step's
        # autograd graph and makes the list awkward to plot
        cost_value = cost.item()
        costs.append(cost_value)
        print("Cost after iteration {}: {:.6f} | Train Acc: {:.2f} | Test Acc: {:.2f}".format(
            ITER, cost_value, float(train_pred), float(test_pred)))
   

### References
- [A Logistic Regression Model from Scratch](https://colab.research.google.com/drive/1iBoJ0kngkOthy7SgVaVQA1aHEROt5mra?usp=sharing)