In [None]:
! pip install torchvision



In [38]:
# Import libraries
# import matplotlib.pyplot as plt
import numpy as np
import pickle 
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.optim as optim

# Load the data
dataset = pickle.load(open('ocr_insurance_dataset.pkl', 'rb'))

In [39]:
# Defining OCRModel class

class OCRModel(nn.Module):
    def __init__(self):
        super(OCRModel, self).__init__()
        
        # Define the image layer
        self.image_layer = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels=16, out_channels=64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        
        # Define the classifier layer
        self.fc = nn.Sequential(
            nn.Linear(64*16*16, 128),
            nn.ReLU(),
            nn.Linear(128, 2)
        )
    
    def forward(self, image, type_input):
        x = self.image_layer(image)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

In [40]:
# Preparing dataloader, model and parameters

dataloader = DataLoader(dataset, batch_size=10, shuffle=True)

model = OCRModel()
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

In [41]:
# Training model

for epoch in range(10):
    for (images, types), labels in dataloader:
        
        # Zero the parameter gradients
        optimizer.zero_grad()
        
        # Forward pass
        outputs = model(images, types)
        loss = criterion(outputs, labels)
        
        # Backward pass and optimization
        loss.backward()
        optimizer.step()
        
    print(f'Epoch [{epoch+1}/10], Loss: {loss.item():.4f}')



Epoch [1/10], Loss: 0.6863
Epoch [2/10], Loss: 0.6721
Epoch [3/10], Loss: 0.6482
Epoch [4/10], Loss: 0.7167
Epoch [5/10], Loss: 0.5054
Epoch [6/10], Loss: 0.4464
Epoch [7/10], Loss: 0.3759
Epoch [8/10], Loss: 0.2777
Epoch [9/10], Loss: 0.0850
Epoch [10/10], Loss: 0.1337
