In [11]:
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import os

class Sandpile:
    def __init__(self, grid_size):
        self.grid_size = grid_size
        self.grid = np.zeros((grid_size, grid_size), dtype=int)
        self.t = 0
    
    def add_sand(self, x, y, grains=1):
        self.grid[x, y] += grains
        self.topple()
        self.t += 1

    def find_high_piles(self):
        x, y = np.where(self.grid > 3)
        return list(zip(x, y))

        
    def topple(self):
        unstable = True
        while unstable:
            unstable = False
            high_piles_coords = self.find_high_piles()
            for x, y in high_piles_coords:
                self.grid[x, y] -= 4
                if x > 0: self.grid[x-1, y] += 1
                if x < self.grid_size - 1: self.grid[x+1, y] += 1
                if y > 0: self.grid[x, y-1] += 1
                if y < self.grid_size - 1: self.grid[x, y+1] += 1
            self.t += 1
            if len(self.find_high_piles()) != 0:
                unstable = True                

    def get_grid(self):
        return self.grid
    
    def three_d_histogram(self, num_iterations, dir_name):
        os.makedirs(dir_name, exist_ok=True)
        for i in tqdm(range(num_iterations)):
            x = np.random.randint(0, self.grid_size)
            y = np.random.randint(0, self.grid_size)
            self.add_sand(x, y)

            # Create a 3D histogram of the current grid state
            fig = plt.figure()
            ax = fig.add_subplot(111, projection='3d')

            xpos, ypos = np.meshgrid(np.arange(self.grid_size), np.arange(self.grid_size))
            xpos = xpos.flatten()
            ypos = ypos.flatten()
            zpos = np.zeros_like(xpos)

            dx = dy = 1
            dz = self.grid.flatten()

            ax.bar3d(xpos, ypos, zpos, dx, dy, dz, cmap='viridis')

            plt.title('Sandpile Model Histogram')
            plt.savefig(os.path.join(dir_name, f'sandpile_histogram_{i}.png'))
            plt.close()  # Close the plot to avoid memory issues


In [16]:
# Example usage
grid_size = 10
num_grains = 100

sandpile = Sandpile(grid_size)

# Generate 100 random (x, y) coordinates
coordinates = np.random.randint(0, grid_size, size=(num_grains, 2))

for coord in tqdm(coordinates):
    sandpile.add_sand(coord[0], coord[1])
    
sandpile.three_d_histogram()

100%|██████████| 100/100 [00:00<00:00, 51168.77it/s]


TypeError: Sandpile.three_d_histogram() missing 2 required positional arguments: 'num_iterations' and 'dir_name'