In [None]:
import os
import time
from io import open

import MySQLdb
import cv2 as cv2
import numpy as np
import serial
import torch
import torch.nn as nn
import torchvision
from PIL import Image
from torch.autograd import Variable
from torchvision import datasets, transforms

# from tfr_nn_custom_pretrained_model import ConvNet

In [None]:
class ConvNet(nn.Module):
    """Small CNN classifier for 3-channel 30x30 images.

    Pipeline:
        conv(3->12) + BN + ReLU -> 2x2 max-pool -> conv(12->20) + ReLU
        -> conv(20->32) + BN + ReLU -> flatten -> linear(32*15*15 -> num_classes)

    The attribute names (conv1, bn1, conv2, conv3, bn3, fc, ...) are the state_dict
    keys, so they must not change or previously saved checkpoints will not load.
    """

    # Side length of the feature map after the single 2x2 max-pool on a 30x30 input.
    _POOLED_SIDE = 15
    # Number of values per sample fed to the final linear layer.
    _FLAT_FEATURES = 32 * _POOLED_SIDE * _POOLED_SIDE

    def __init__(self, num_classes=3):
        super().__init__()
        # Stage 1: (N, 3, 30, 30) -> (N, 12, 30, 30); 3x3 kernel with padding 1 keeps the size.
        self.conv1 = nn.Conv2d(3, 12, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(12)
        self.relu1 = nn.ReLU()

        # Down-sample: (N, 12, 30, 30) -> (N, 12, 15, 15).
        self.maxpool1 = nn.MaxPool2d(kernel_size=2)

        # Stage 2: (N, 12, 15, 15) -> (N, 20, 15, 15); this stage has no batch-norm.
        self.conv2 = nn.Conv2d(12, 20, kernel_size=3, stride=1, padding=1)
        self.relu2 = nn.ReLU()

        # Stage 3: (N, 20, 15, 15) -> (N, 32, 15, 15).
        self.conv3 = nn.Conv2d(20, 32, kernel_size=3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(32)
        self.relu3 = nn.ReLU()

        # Classifier head over the flattened feature map.
        self.fc = nn.Linear(self._FLAT_FEATURES, num_classes)

    def forward(self, input):
        """Return raw class scores of shape (N, num_classes) for a (N, 3, 30, 30) batch."""
        features = input
        for stage in (self.conv1, self.bn1, self.relu1,
                      self.maxpool1,
                      self.conv2, self.relu2,
                      self.conv3, self.bn3, self.relu3):
            features = stage(features)
        # Fold the (32, 15, 15) feature map of each sample into one row.
        return self.fc(features.view(-1, self._FLAT_FEATURES))

In [None]:
# function to detect an item
def detectItem(cam_port=0, model_path='bestFineTuneModel.model'):
    """Capture one frame from the camera and classify it with the fine-tuned ConvNet.

    Args:
        cam_port: OpenCV camera index to read from (default 0, as before).
        model_path: Path of the saved ConvNet state_dict.

    Returns:
        The predicted class name: one of 'Humidity', 'Bluetooth', 'Transistor'.

    Raises:
        RuntimeError: if no frame could be captured from the camera (previously this
            surfaced as an UnboundLocalError on `prediction`).
    """
    classes = ['Humidity', 'Bluetooth', 'Transistor']

    cam = cv2.VideoCapture(cam_port)
    try:
        result, image = cam.read()
    finally:
        # Release the device so later calls (and other programs) can open it again.
        cam.release()

    if not result:
        print("Error: Failed to capture image from the camera")
        raise RuntimeError("Failed to capture image from the camera")

    # OpenCV returns a BGR numpy array, but transforms.Resize expects a PIL image (or
    # tensor). Convert to RGB so the channel order matches the ImageNet RGB mean/std
    # used in Normalize — assumes the model was trained on RGB images; TODO confirm
    # against the training pipeline.
    pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

    data_transforms = transforms.Compose([
        transforms.Resize((30, 30)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    torch_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # map_location lets a checkpoint saved on a GPU machine load on a CPU-only one.
    checkpoint = torch.load(model_path, map_location=torch_device)
    model = ConvNet(num_classes=len(classes))
    model.load_state_dict(checkpoint)
    model.to(torch_device)
    model.eval()

    # Add the batch dimension and put the input on the model's device. The old
    # `image_tensor.cuda()` discarded its result, so nothing was actually moved.
    image_tensor = data_transforms(pil_image).float().unsqueeze(0).to(torch_device)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        output = model(image_tensor)

    # Index of the highest score for the single sample in the batch.
    pred_index = int(output.argmax(dim=1).item())
    return classes[pred_index]

In [None]:
# Connect to MySQL and the Arduino, then serve RFID borrow/return sessions forever.
#
# Serial protocol as implemented here (the Arduino sketch is not in this notebook):
#   Arduino -> PC : RFID tag
#   PC -> Arduino : first column of the matching `user` row, or "  Invalid ID  "
#   then, until the Arduino sends "#":
#   Arduino -> PC : option  "1" = borrow, "2" = return, "#" = end session
#     borrow: item number, then (only for item <= MAX_ITEM_ID) a quantity;
#             PC replies "ok", or the current stock level if it is insufficient
#     return: "*"; PC photographs the item, replies with the predicted class name,
#             then reads one status line back (printed as the servo message)
#   Every PC -> Arduino message is terminated with a carriage return.

# Connection settings, overridable through environment variables. The defaults are
# the values that used to be hard-coded; do not commit real credentials here.
DB_HOST = os.environ.get("RFID_DB_HOST", "localhost")
DB_USER = os.environ.get("RFID_DB_USER", "root")
DB_PASSWORD = os.environ.get("RFID_DB_PASSWORD", "")
DB_NAME = os.environ.get("RFID_DB_NAME", "rfid_read")
device = os.environ.get("RFID_SERIAL_PORT", "COM9")
BAUD_RATE = 9600
MAX_ITEM_ID = 3  # item numbers above this are ignored (no reply is sent)

# MySQLdb.connect raises on failure; the old `or die(...)` was PHP and could never run.
dbConn = MySQLdb.connect(DB_HOST, DB_USER, DB_PASSWORD, DB_NAME)

print("Trying...", device)
try:
    arduino = serial.Serial(device, BAUD_RATE)
except serial.SerialException:
    # Stop here: previously the loop below ran forever, failing on the undefined `arduino`.
    print("Failed to connect on", device)
    dbConn.close()
    raise


def read_arduino_line():
    """Block until the Arduino sends a full line; return it decoded, without trailing CR/LF."""
    # The port was opened without a read timeout, so readline() already waits for '\n'.
    # This replaces the old `while inWaiting() == 0: pass` busy-waits that pinned a CPU core.
    return arduino.readline().decode(errors="replace").strip("\r\n")


def send_to_arduino(message):
    """Send str(message) to the Arduino, terminated with a carriage return."""
    arduino.write((str(message) + "\r").encode())


def handle_borrow(cursor, order_id):
    """Borrow flow: read item number and quantity, record the order if stock allows."""
    item = int(read_arduino_line())
    print("item = " + str(item))
    if item > MAX_ITEM_ID:
        # Unknown item number: nothing is sent back (unchanged protocol behaviour).
        return
    quantity = int(read_arduino_line())

    cursor.execute("SELECT stock FROM `inventory` WHERE id = %s", (item,))
    row = cursor.fetchone()
    # A missing inventory row is treated as zero stock instead of crashing on None[0].
    stock = row[0] if row is not None else 0
    print(str(stock))

    if quantity > 0 and stock - quantity >= 0:
        send_to_arduino("ok")
        # Placeholders must not be quoted: MySQLdb quotes/escapes the values itself.
        cursor.execute(
            "INSERT INTO `orders` (`ID`, `item_id`, `quantity`) VALUES (%s, %s, %s)",
            (order_id, item, quantity),
        )
        # Decrement by this borrow's quantity directly. The previous
        # `UPDATE inventory JOIN orders ...` updates each inventory row at most once,
        # so borrowing the same item twice in one session under-counted the reduction.
        cursor.execute(
            "UPDATE inventory SET stock = stock - %s WHERE ID = %s",
            (quantity, item),
        )
        cursor.execute(
            "UPDATE `rfid_data` SET `state` = '0' WHERE `rfid_data`.`ID` = %s",
            (order_id,),
        )
        # One commit for all three statements so a failure cannot leave a half-recorded order.
        dbConn.commit()
    else:
        send_to_arduino(stock)


def handle_return(cursor, order_id):
    """Return flow: on '*', classify the item with the camera and restock it by one."""
    item_placed = read_arduino_line()
    if item_placed != "*":
        return
    print(item_placed)

    prediction = detectItem()
    send_to_arduino(prediction)

    servo_msg = read_arduino_line()
    print(servo_msg)

    # Parameterised instead of string concatenation.
    # NOTE(review): restocks by one unit regardless of the borrowed quantity.
    cursor.execute(
        "UPDATE inventory SET stock = stock + 1 WHERE inventory.item = %s",
        (prediction,),
    )
    cursor.execute(
        "UPDATE `rfid_data` SET `state` = '1' WHERE `rfid_data`.`ID` = %s",
        (order_id,),
    )
    dbConn.commit()


def handle_session(rfid):
    """Authenticate `rfid`, open an rfid_data record, and process options until '#'."""
    cursor = dbConn.cursor()
    try:
        # Exact match: LIKE would treat '%' / '_' in the incoming text as wildcards.
        cursor.execute("SELECT * FROM user WHERE RFID = %s", (rfid,))
        row = cursor.fetchone()
        if row is None:
            print("It is not a valid ID Card")
            send_to_arduino("  Invalid ID  ")
            return

        # send_to_arduino applies str(), so a numeric first column no longer raises TypeError.
        send_to_arduino(row[0])
        cursor.execute("INSERT INTO rfid_data (ID, Member_ID) VALUES (NULL, %s)", (rfid,))
        dbConn.commit()
        order_id = int(cursor.lastrowid)
        # TODO: check for outstanding orders (not returned within 7 days), e.g.
        #   SELECT ID FROM rfid_data WHERE Member_ID = %s AND state = '0'
        #     AND DATEDIFF(CURRENT_DATE, date) > 7;

        while True:
            option = read_arduino_line()
            print("option: " + option)
            if option == "1":
                handle_borrow(cursor, order_id)
            elif option == "2":
                handle_return(cursor, order_id)
            elif option == "#":
                # TODO: consider deleting this rfid_data row when its state is still NULL
                # (nothing was borrowed or returned).
                break
    except MySQLdb.IntegrityError:
        dbConn.rollback()
        print("failed to insert data")
    except Exception:
        # Discard uncommitted statements so a later commit cannot persist half a borrow.
        dbConn.rollback()
        raise
    finally:
        cursor.close()


try:
    while True:
        time.sleep(1)
        try:
            # Strip CR/LF and padding: the raw line used to be queried and stored verbatim.
            rfid = read_arduino_line().strip()
            print(rfid)
            if not rfid:
                # Blank lines are not card swipes; skip them without querying the DB.
                continue
            handle_session(rfid)
        except Exception as e:
            print(e)
            print("Processing")
finally:
    # Runs when the kernel is interrupted: free the serial port and DB connection so the
    # cell can be re-run without restarting the kernel.
    arduino.close()
    dbConn.close()
