### Importing libraries

In [82]:
import sys
sys.path.append('../')
import pymongo
import numpy as np
import chess
from multiprocessing import Pool
import csv
import os
from sklearn.preprocessing import StandardScaler
import pandas as pd
from sqlalchemy.orm import  Session
from tqdm import tqdm
from joblib import load
import math
from pymongo import MongoClient
import hashlib
import json


import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim

current_working_directory = os.getcwd()

# print output to the console
print(current_working_directory)

C:\Users\ethan\git\Full_Chess_App


# Importing my stuff

In [2]:
from Chess_Model.src.model.config.config import Settings
from Chess_Model.src.model.classes.sqlite.dependencies import  fetch_one_game_position, fetch_all_game_positions_rollup,get_rollup_row_count,board_to_GamePostition
from Chess_Model.src.model.classes.sqlite.models import GamePositions
from Chess_Model.src.model.classes.sqlite.database import SessionLocal
from Chess_Model.src.model.classes.cnn_bb_scorer import boardCnnEval

# Constants

In [3]:
metadata_key = 'metadata'
bitboards_key = 'positions_data'
results_key = 'game_results'

# Connect to MongoDB
mongo_url = "mongodb://localhost:27017/"
client = MongoClient(mongo_url)

# Create or switch to a database
db_name = "mydatabase"
db = client[db_name]

# Create or switch to a collection
main_collection = "main_collection"
collection = db["main_collection"]

validation_collection_key = "validation_data"
testing_collection_key = "testing_data"
training_collection_key = "training_data"

valid_collection = db[validation_collection_key]
test_collection = db[testing_collection_key]
train_collection = db[training_collection_key]

evaluator = boardCnnEval()

validation_ratio = 0.15
test_ratio = 0.15

In [4]:
collection.delete_many({})
valid_collection.delete_many({})
test_collection.delete_many({})
train_collection.delete_many({})

DeleteResult({'n': 10530, 'ok': 1.0}, acknowledged=True)

In [5]:
sample = []
with SessionLocal() as db:
    i = 0
    for game in fetch_all_game_positions_rollup(db=db,yield_size=1):
        sample.append(game)
        i += 1
        if i >= 1000:
            break

Exception ignored in: <generator object fetch_all_game_positions_rollup at 0x000001E3049E3100>
Traceback (most recent call last):
  File "C:\Users\ethan\AppData\Local\Temp\ipykernel_5496\1352021742.py", line 8, in <module>
RuntimeError: generator ignored GeneratorExit


In [6]:
sample

[<Chess_Model.src.model.classes.sqlite.dependencies.GamePositionWithWinBuckets at 0x1e304cc8790>,
 <Chess_Model.src.model.classes.sqlite.dependencies.GamePositionWithWinBuckets at 0x1e304b93610>,
 <Chess_Model.src.model.classes.sqlite.dependencies.GamePositionWithWinBuckets at 0x1e304d07f90>,
 <Chess_Model.src.model.classes.sqlite.dependencies.GamePositionWithWinBuckets at 0x1e304abca10>,
 <Chess_Model.src.model.classes.sqlite.dependencies.GamePositionWithWinBuckets at 0x1e304abcd10>,
 <Chess_Model.src.model.classes.sqlite.dependencies.GamePositionWithWinBuckets at 0x1e304abce50>,
 <Chess_Model.src.model.classes.sqlite.dependencies.GamePositionWithWinBuckets at 0x1e304abd1d0>,
 <Chess_Model.src.model.classes.sqlite.dependencies.GamePositionWithWinBuckets at 0x1e304abd350>,
 <Chess_Model.src.model.classes.sqlite.dependencies.GamePositionWithWinBuckets at 0x1e304abd410>,
 <Chess_Model.src.model.classes.sqlite.dependencies.GamePositionWithWinBuckets at 0x1e304abd4d0>,
 <Chess_Model.src.mo

In [7]:
dataset = []
for game in sample:
    evaluator.setup_parameters_gamepositions(game=game)
    dataset.append(evaluator.get_board_scores())


In [8]:
evaluator.setup_parameters_gamepositions(game=sample[0])
f = evaluator.get_board_scores()

In [9]:
dataset[0]

{'metadata': [5,
  0,
  2,
  0,
  0,
  5,
  0,
  2,
  0,
  0,
  7,
  7,
  1,
  1,
  0,
  0,
  0,
  0,
  19,
  15,
  1.2666666666666666,
  0.7894736842105263,
  0,
  1,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  3,
  1,
  1,
  1,
  0,
  0,
  1,
  0],
 'positions_data': [551098122496,
  0,
  144115222435594240,
  0,
  0,
  2097152,
  9218374240305152,
  0,
  1152956688978935808,
  0,
  0,
  17179869184],
 'game_results': [0.0, 0.0, 1.0]}

In [10]:
def convert_large_ints(data):
    # Convert any large integers to strings to prevent OverflowError

    for i in range(0,len(data)):
        data[i] = str(data[i])
    return data

def get_hash_id(doc):
    """ Generate a hash value for the item ID and scale it to the range 0 to bins-1 """
    dict_string = json.dumps(doc, sort_keys=True)
    hash_object = hashlib.sha256(dict_string.encode())
    hex_dig = hash_object.hexdigest()
    return hex_dig

def game_to_doc_evaluation(game):
    evaluator.setup_parameters_gamepositions(game=game)
    board_scores = evaluator.get_board_scores()
    document = {
        "metadata": board_scores['metadata'],  # Ensure this does not contain large integers
        "positions_data": convert_large_ints(board_scores['positions_data']),
        "game_results": board_scores['game_results']
    }

    
    
    return(document)

def process_sqlite_boards_to_mongo(batch_size: int = 1):
    with SessionLocal() as db:
        row_count = get_rollup_row_count(db=db)
        batch = fetch_all_game_positions_rollup(yield_size=500, db=db)
        dataset = []  # List to accumulate serialized examples
        for game in tqdm(batch, total=row_count, desc="Processing Feature Data"):
            try:
                if game:
                    
                    document = game_to_doc_evaluation(game=game)

                    document['_id'] = get_hash_id(doc=document)
                    
                    dataset.append(document)



                    # Check if we've accumulated enough examples to write a batch
                    if len(dataset) >= batch_size:

                        collection.insert_many(dataset)
                        dataset = []  # Reset the list after writing
                else:
                    return 1
            except Exception as e:
                raise Exception(e)
        

process_sqlite_boards_to_mongo()


Processing Feature Data: 100%|â–ˆ| 15


1

In [11]:
validation_min = validation_ratio * 100
test_min = test_ratio*100 + validation_min

def collection_decider(hash_value: int):
    if hash_value < validation_min:
        return validation_collection_key
    elif hash_value < test_min:
        return testing_collection_key
    else:
        return training_collection_key
    

In [12]:
def mongoConnection(mongoUrl,dbName):
    try:
     # Fetching client using pymongo
        client = MongoClient(mongoUrl)
        # Holding our database
        mydatabase = client[dbName]
        return mydatabase
    except Exception as e:
        print("Error occured while connecting to database!")
        return False

def iteratingFunction(collectionName,mongoUrl=mongo_url,dbName=db_name,batch_size: int = 1000):
    try:
        db = mongoConnection(mongoUrl,dbName)
        # Fetching the collection
        collection = db[collectionName]
        batch = collection.find({}, batch_size = batch_size)

        for doc in batch:
            yield doc

    except Exception as e:
         print("Error occured!: ",e)

iteratingFunction(mongoUrl=mongo_url,dbName=db_name,collectionName=main_collection)

<generator object iteratingFunction at 0x000001E304C72C20>

In [13]:
def mongo_document_generator(yield_size=5):
    
    # specified_collection = db[collection_key]
    
    cursor = collection.find({}, no_cursor_timeout=True)
    
    # Fetch documents in batches as per the yield_size
    batch = []
    for document in cursor:
        batch.append(document)
        if len(batch) >= yield_size:
            yield batch
            batch = []
    
    # Yield any remaining documents if batch isn't empty
    if batch:
        yield batch

    # Close the cursor and connection when done
    cursor.close()
    # client.close()
    #Definitely fix this later or implement a cursor closer overall in script

In [14]:
def get_hash_ring_value(id, bins=100):
    hash_value = int(id, 16) % bins
    return hash_value
    
def shuffle_and_split_set(batch_size:int = 100, yield_size:int = 1):
    i = 0
    collections_and_keys = [[testing_collection_key, test_collection],
                            [validation_collection_key, valid_collection], 
                            [training_collection_key, train_collection]]
    collected_docs = {testing_collection_key:[],
                     validation_collection_key:[],
                     training_collection_key:[]}
    for doc in iteratingFunction(mongoUrl=mongo_url,dbName=db_name,collectionName=main_collection):
    # for docs in mongo_document_generator(yield_size=yield_size):
        # for doc in docs:
            
        bin = get_hash_ring_value(id=doc['_id'])
        
        collection = collection_decider(hash_value=bin)
        collected_docs[collection].append(doc)
        
        i += 1
        if i >= batch_size:
            
            for key,collection_client in collections_and_keys:
                if collected_docs[key] == []:
                    pass
                else:
                    collection_client.insert_many(collected_docs[key])
            
            collected_docs = {testing_collection_key:[],
                 validation_collection_key:[],
                 training_collection_key:[]}
            i = 0
            
shuffle_and_split_set(batch_size = 100, yield_size = 10)                

# C:\Users\ethan\AppData\Local\Programs\Python\Python311\Lib\site-packages\pymongo\collection.py:1685: UserWarning: use an explicit session with no_cursor_timeout=True otherwise the cursor may still timeout after 30 minutes, for more info see https://mongodb.com/docs/v4.4/reference/method/cursor.noCursorTimeout/#session-idle-timeout-overrides-nocursortimeout
#   return Cursor(self, *args, **kwargs)
        

In [15]:
db = client["mydatabase"]
collection_stats = db.command("collstats", testing_collection_key)
print(f"Document Count: {collection_stats['count']}")
collection_stats = db.command("collstats", validation_collection_key)
print(f"Document Count: {collection_stats['count']}")
collection_stats = db.command("collstats", training_collection_key)
print(f"Document Count: {collection_stats['count']}")

Document Count: 2193
Document Count: 2277
Document Count: 10530


In [16]:
# Fetch the documents
n = 1
documents = train_collection.find({}, {'metadata': 1,'_id': 0}).limit(1)

# Print each document
for document in documents:
    print(document)

{'metadata': [5, 0, 2, 0, 0, 5, 0, 2, 0, 0, 7, 7, 1, 1, 0, 0, 0, 0, 18, 13, 1.3846153846153846, 0.7222222222222222, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 1, 0, 0, 0, 0, 1]}


In [17]:
def calc_mongo_train(batch_size: int = 1000):
    collection_stats = db.command("collstats", training_collection_key)
    m = int(collection_stats['count']) 
    sample_doc = train_collection.find_one({})
    n = len(list(sample_doc['metadata']))
    mean = calc_mean(m=m,n=n,batch_size = batch_size)
    std = calc_std(m=m,n=n,mean=mean,batch_size = batch_size)
    return mean, std

def calc_mean(m,n,batch_size: int = 1000):
    mean = np.zeros((1, n))
    md_lists = []
    i  = 0
    total_docs = 0
    

        
    for doc in train_collection.find({}, {'metadata': 1,'_id': 0},batch_size = batch_size):
        md_lists.append(doc['metadata'])
        i += 1
        total_docs += 1
        if i >= batch_size:

            mean, md_lists = mean_aggregate(curr_md_list=md_lists, curr_mean=mean)
            i = 0
            
    if md_lists == []:
        mean, md_lists = mean_aggregate(curr_md_list=md_lists, curr_mean=mean)
        
    print(f"docs calcd: {total_docs}, m: {m}, n: {n}")
    mean = (1/m) * mean
    return mean
    

def calc_std(m,n,mean,batch_size: int = 1000):    
    std = np.zeros((1, n))
    md_lists = []
    i  = 0
    total_docs = 0
    
    for doc in train_collection.find({}, {'metadata': 1,'_id': 0},batch_size = batch_size):
        md_lists.append(doc['metadata'])
        i += 1
        total_docs += 1
        if i >= batch_size:
            std, md_lists = std_aggregate(curr_md_list = md_lists,curr_std = std,mean = mean)
            i = 0
            
    if md_lists == []:
        std, md_lists = std_aggregate(curr_md_list = md_lists,curr_std = std,mean = mean)
    print(f"docs calcd: {total_docs}, m: {m}, n: {n}")
    std = (1/m) * std
    std = np.sqrt(std)
    
    return std

def std_aggregate(curr_md_list,curr_std,mean):
    curr_md_list = np.array(curr_md_list)
    curr_md_list = (curr_md_list-mean) ** 2
    curr_md_list = np.sum(curr_md_list,axis = 0,keepdims=True)
    curr_std = curr_std + curr_md_list
    curr_md_list = []
    return curr_std, curr_md_list

def mean_aggregate(curr_md_list, curr_mean):
    curr_md_list = np.array(curr_md_list)
    curr_md_list = np.sum(curr_md_list,axis = 0,keepdims=True)
    curr_mean = curr_mean + curr_md_list
    curr_md_list = []
    return curr_mean, curr_md_list
    
train_mean,train_std = calc_mongo_train()    
       

docs calcd: 10530, m: 10530, n: 40
docs calcd: 10530, m: 10530, n: 40


In [18]:
print(train_std.shape)
print(train_mean.shape)

(1, 40)
(1, 40)


In [19]:
print(train_std)

[[ 1.67998454  0.77597053  0.75656886  0.60216061  0.44372332  1.62987469
   0.75083981  0.80289592  0.60150377  0.44529029  3.24852947  3.28131152
   0.45594077  0.47438099  0.48721459  0.48706524  0.45100037  0.44539655
  10.40961002  9.98440855  2.58493933  2.32523651  0.34842904  0.35381775
   0.0550385   0.          0.          0.          0.          0.
   0.          0.          1.60581062  1.53912994  0.32239083  0.33197067
   0.02922214  0.          0.48787261  0.48786866]]


# Come back to here later after implementing a larger dataset

In [43]:
documents = collection.find({}, {'metadata': 1, '_id': 0})

# Initialize a variable to store the first metadata list
sample = []
for document in documents:
    sample = document['metadata']
    break
print(sample)
# Get the length of the sample list
n = len(sample)

# Initialize a list to track zero values across all metadata lists
always_zero = [True] * n

# Reiterate through all documents to update the always_zero list
for document in documents:
    metadata = document['metadata']
    for i in range(n):
        if metadata[i] != 0:
            always_zero[i] = False

# Print the indexes of the values that are always zero
always_zero_indexes = [i for i, is_zero in enumerate(always_zero) if is_zero]
print("Indexes of values that are always zero:", always_zero_indexes)

[5, 0, 2, 0, 0, 5, 0, 2, 0, 0, 7, 7, 1, 1, 0, 0, 0, 0, 19, 15, 1.2666666666666666, 0.7894736842105263, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 3, 1, 1, 1, 0, 0, 1, 0]
Indexes of values that are always zero: [25, 26, 27, 28, 29, 30, 31, 37]


In [21]:
md_keys = ['white pawns', 'white knights', 'white bishops', 'white rooks', 'white queens', 'black pawns', 'black knights', 'black bishops', 'black rooks', 'black queens', 'total black pieces', 'total white pieces', 'white has bishop pair', 'black has bishop pair', 'white has knight bishop pair', 'black has knight bishop pair', 'white has knight pair', 'black has knight pair', 'white moves', 'black moves', 'white to black moves', 'white to white moves', 'Beginning Game', 'Middle Game', 'End Game', 'black can be drawn', 'black promote to queen', 'white can be drawn', 'white promote to queen', 'can be drawn', 'black queen can be taken', 'white queen can be taken', 'white attacks', 'black attacks', 'white can attack', 'black can attack', 'checkmate', 'stalemate', 'white turn', 'black turn']
md_vals = [5, 0, 2, 0, 0, 5, 0, 2, 0, 0, 7, 7, 1, 1, 0, 0, 0, 0, 19, 15, 1.2666666666666666, 0.7894736842105263, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 3, 1, 1, 1, 0, 0, 1, 0]

In [22]:
md_keys[25:31], md_keys[37]

(['black can be drawn',
  'black promote to queen',
  'white can be drawn',
  'white promote to queen',
  'can be drawn',
  'black queen can be taken'],
 'stalemate')

# Gets rid of zero std's

In [23]:
train_std = np.where(train_std == 0, 1e-10, train_std)

In [70]:
def bitboard_to_matrix(bitboard):
    return np.array([(bitboard >> shift) & 1 for shift in range(64)]).reshape(8, 8)
def create_cnn_input(bitboards):
    layers = []
    for bb in bitboards:  # Ensure consistent order
        matrix = bitboard_to_matrix(int(bb))
        # print(matrix)
        layers.append(matrix)
    return np.stack(layers)

In [71]:
bb= [551098122496,0,144115222435594240,0,0,2097152,9218374240305152,0,1152956688978935808,0,0,17179869184]
# tmp_dict = {'White Pawn': 65280,
#  'White Knight': 66,
#  'White Bishop': 36,
#  'White Rook': 129,
#  'White Queen': 8,
#  'White King': 16,
#  'Black Pawn': 71776119061217280,
#  'Black Knight': 4755801206503243776,
#  'Black Bishop': 2594073385365405696,
#  'Black Rook': 9295429630892703744,
#  'Black Queen': 576460752303423488,
#  'Black King': 1152921504606846976}

# bb = list(tmp_dict.values())

c = create_cnn_input(bb)
c

array([[[0, 0, 0, 0, 0, 0, 0, 0],
        [1, 0, 0, 0, 0, 0, 0, 0],
        [0, 1, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 1, 0, 1, 0],
        [0, 0, 0, 0, 0, 0, 0, 1],
        [0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0]],

       [[0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0]],

       [[0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 1, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0],
        [0, 1, 0, 0, 0, 0, 0, 0]],

       [[0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0],
        

In [72]:
def iteratingFunctionScaled(collectionName,mongoUrl=mongo_url,dbName=db_name,batch_size: int = 1000):
    try:
        db = mongoConnection(mongoUrl,dbName)
        # Fetching the collection
        collection = db[collectionName]
        batch = collection.find({}, {'_id': 0}, batch_size = batch_size)

        for doc in batch:
            bitboards = []
            doc['positions_data'] = create_cnn_input(doc['positions_data'])
            doc['metadata'] = (doc['metadata'] - train_mean) / train_std
            doc['game_results'] = np.array(doc['game_results'])
            yield doc

    except Exception as e:
         print("Error occured!: ",e)

In [73]:
document_generator = iteratingFunctionScaled(mongoUrl=mongo_url,dbName=db_name,collectionName=main_collection)
try:
    sample = next(document_generator)
    # print("Sample Document Processed Data:", sample_data)
except StopIteration:
    print("No more documents in the collection.")
except Exception as e:
    print("An error occurred:", e)




In [74]:
class MongoDBDataset(Dataset):
    def __init__(self, collectionName, mongoUrl, dbName, batch_size=1):
        self.collectionName = collectionName
        self.mongoUrl = mongoUrl
        self.dbName = dbName
        self.batch_size = batch_size
        print(collectionName)
        print(mongoUrl)
        print(dbName)
        
        
        # Fetch data with a progress bar
        self.data = []
        try:
            for doc in tqdm(iteratingFunctionScaled(collectionName=collectionName,
                                                    mongoUrl=mongoUrl,
                                                    dbName=dbName, 
                                                    batch_size=batch_size), desc="Loading data"):
                self.data.append(doc)
            print(f"Loaded {len(self.data)} documents.")
        except Exception as e:
            print("Error during data loading:", e)
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        try:
            if idx >= len(self.data):
                raise IndexError("Index out of range")
            
            doc = self.data[idx]
            positions_data = doc['positions_data']
            metadata = doc['metadata']
            game_results = doc['game_results']
            
            print(f"Returning item {idx} from dataset")
            
            return (torch.tensor(positions_data, dtype=torch.float32),
                    torch.tensor(metadata, dtype=torch.float32),
                    torch.tensor(game_results, dtype=torch.float32))
        except Exception as e:
            print(f"Error in __getitem__ at index {idx}: {e}")
            raise




In [75]:
document_generator = iteratingFunctionScaled(mongoUrl=mongo_url,dbName=db_name,collectionName=main_collection)

In [76]:
dataset = MongoDBDataset(collectionName=training_collection_key, mongoUrl=mongo_url, dbName=db_name, batch_size=1)

# Ensure that the dataset is not empty before creating the DataLoader
if len(dataset) > 0:
    dataloader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0)
else:
    print("Dataset is empty. Please check the data loading process.")


# Debugging: Manually iterate through the dataset
print("Manually iterating through the dataset:")
for i in range(min(len(dataset), 5)):  # Check the first 5 items
    item = dataset[i]
    print(f"Item {i}: {item}")

# Debugging: Iterate through the DataLoader
print("Iterating through the DataLoader:")
for batch in dataloader:
    print("Batch loaded")
    batch_x1, batch_x2, batch_labels = batch
    print("Batch shapes:", batch_x1.shape, batch_x2.shape, batch_labels.shape)
    break  # Break after the first batch for debugging

# Debugging: Manually iterate through the dataset
print("Manually iterating through the dataset:")
for i in range(min(len(dataset), 5)):  # Check the first 5 items
    item = dataset[i]
    print(f"Item {i}: {item}")

# Debugging: Iterate through the DataLoader
print("Iterating through the DataLoader:")
for batch in dataloader:
    print("Batch loaded")
    batch_x1, batch_x2, batch_labels = batch
    print("Batch shapes:", batch_x1.shape, batch_x2.shape, batch_labels.shape)
    break  # Break after the first batch for debugging

training_data
mongodb://localhost:27017/
mydatabase


Loading data: 10530it [00:10, 998.29it/s] 

Loaded 10530 documents.
Manually iterating through the dataset:
Returning item 0 from dataset
Item 0: (tensor([[[0., 0., 0., 0., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0., 0., 0., 0.],
         [0., 1., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 1., 0., 1., 0.],
         [0., 0., 0., 0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 1., 0., 0.,




In [83]:
class MultipleInputModel(nn.Module):
    def __init__(self, input_planes, additional_features=40):
        super(MultipleInputModel, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=input_planes, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1)
        
        # Fully connected layer to integrate additional features
        self.fc_additional = nn.Linear(additional_features, 64 * 8 * 8)
        
        self.fc1 = nn.Linear(256 * 8 * 8, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, 1)
    
    def forward(self, x, additional_input):
        x = F.relu(self.conv1(x))
        
        # Integrate additional input
        additional_input = F.relu(self.fc_additional(additional_input))
        additional_input = additional_input.view(-1, 64, 8, 8)
        
        # Concatenate additional features with the output of the first convolutional layer
        x = torch.cat((x, additional_input), dim=1)
        
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = x.view(x.size(0), -1)  # Flatten the tensor
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Example usage:
n = 1  # Example value for n
additional_features = 40  # Size of the additional input feature vector
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MultipleInputModel(n, additional_features).to(device)

# Example inputs
x1 = torch.randn(1, n, 8, 8).to(device)  # 8x8xn input
x2 = torch.randn(1, additional_features).to(device)  # nx1 input

output = model(x1, x2)
print(output)

RuntimeError: Given groups=1, weight of size [128, 64, 3, 3], expected input[1, 128, 8, 8] to have 64 channels, but got 128 channels instead

In [85]:

n = 1  # Example value for n
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MultipleInputModel(n).to(device)

# # Example inputs
# x1 = torch.randn(1, n, 8, 8)  # 8x8xn input
# x2 = torch.randn(1, n)        # nx1 input

# output = model(x1, x2)
# print(output)


# Example inputs
x1 = torch.randn(1, n, 8, 8).to(device)  # 8x8xn input
x2 = torch.randn(1, n).to(device)        # nx1 input

output = model(x1, x2)
print(output)

criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Training loop
num_epochs = 5

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    
    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")

    for batch_x1, batch_x2, batch_labels in progress_bar:

        # Move data to the appropriate device (CPU/GPU)
        batch_x1 = batch_x1.permute(0, 3, 1, 2).to(device)  # Change shape to (batch_size, n, 8, 8)
        batch_x2 = batch_x2.to(device)
        batch_labels = batch_labels.to(device)
        
        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(batch_x1, batch_x2)
        loss = criterion(outputs, batch_labels)

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(dataloader):.4f}")

RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x1 and 40x4096)